path: root/contrib/libs/llvm14/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
//===- TruncInstCombine.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TruncInstCombine - looks for expression dags post-dominated by TruncInst and
// for each eligible dag, it will create a reduced bit-width expression, replace
// the old expression with this new one and remove the old expression.
// An eligible expression dag is one that:
//   1. Contains only supported instructions.
//   2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
//   3. Can be evaluated into type with reduced legal bit-width.
//   4. All instructions in the dag must not have users outside the dag.
//      The only exception is for {ZExt, SExt}Inst with operand type equal to
//      the new reduced type evaluated in (3).
//
// The motivation for this optimization is that evaluating an expression using
// a smaller bit-width is preferable, especially for vectorization where we can
// fit more values in one vectorized instruction. In addition, this optimization
// may decrease the number of cast instructions, but will not increase it.
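//
// For example, an i32 expression whose only use is a trunc to i16, such as:
//   %zx  = zext i16 %a to i32
//   %add = add i32 %zx, 42
//   %tr  = trunc i32 %add to i16
// can be evaluated directly in i16, removing both casts:
//   %add = add i16 %a, 42
// (an illustrative sketch; the names and constants are arbitrary).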
//
//===----------------------------------------------------------------------===//

#include "AggressiveInstCombineInternal.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Support/KnownBits.h"

using namespace llvm;

#define DEBUG_TYPE "aggressive-instcombine"

STATISTIC(
    NumDAGsReduced,
    "Number of truncations eliminated by reducing bit width of expression DAG");
STATISTIC(NumInstrsReduced,
          "Number of instructions whose bit width was reduced");

/// Given an instruction and a container, fill the container with all the
/// relevant operands of that instruction, with respect to the Trunc expression
/// dag optimization.
static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
  unsigned Opc = I->getOpcode();
  switch (Opc) {
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
    // These CastInsts are considered leaves of the evaluated expression; thus,
    // their operands are not relevant.
    break;
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::Shl:
  case Instruction::LShr:
  case Instruction::AShr:
  case Instruction::UDiv:
  case Instruction::URem:
  case Instruction::InsertElement:
    Ops.push_back(I->getOperand(0));
    Ops.push_back(I->getOperand(1));
    break;
  case Instruction::ExtractElement:
    Ops.push_back(I->getOperand(0));
    break;
  case Instruction::Select:
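    // Only the selected values are part of the evaluated expression; the
    // condition (operand 0) keeps its original type.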
    Ops.push_back(I->getOperand(1));
    Ops.push_back(I->getOperand(2));
    break;
  default:
    llvm_unreachable("Unreachable!");
  }
}

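// Build the expression dag rooted at the current trunc's operand using an
// iterative post-order DFS: the Worklist drives the traversal, the Stack marks
// instructions whose operands are still being processed, and each instruction
// is added to InstInfoMap only after all of its operands have been visited.
// Returns false if an unsupported instruction is reached.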
bool TruncInstCombine::buildTruncExpressionDag() {
  SmallVector<Value *, 8> Worklist;
  SmallVector<Instruction *, 8> Stack;
  // Clear old expression dag.
  InstInfoMap.clear();

  Worklist.push_back(CurrentTruncInst->getOperand(0));

  while (!Worklist.empty()) {
    Value *Curr = Worklist.back();

    if (isa<Constant>(Curr)) {
      Worklist.pop_back();
      continue;
    }

    auto *I = dyn_cast<Instruction>(Curr);
    if (!I)
      return false;

    if (!Stack.empty() && Stack.back() == I) {
      // Already handled all instruction operands, can remove it from both the
      // Worklist and the Stack, and add it to the instruction info map.
      Worklist.pop_back();
      Stack.pop_back();
      // Insert I to the Info map.
      InstInfoMap.insert(std::make_pair(I, Info()));
      continue;
    }

    if (InstInfoMap.count(I)) {
      Worklist.pop_back();
      continue;
    }

    // Add the instruction to the stack before starting to handle its operands.
    Stack.push_back(I);

    unsigned Opc = I->getOpcode();
    switch (Opc) {
    case Instruction::Trunc:
    case Instruction::ZExt:
    case Instruction::SExt:
      // trunc(trunc(x)) -> trunc(x)
      // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest
      // trunc(ext(x)) -> trunc(x) if the source type is larger than the new
      // dest
      break;
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::Shl:
    case Instruction::LShr:
    case Instruction::AShr:
    case Instruction::UDiv:
    case Instruction::URem:
    case Instruction::InsertElement:
    case Instruction::ExtractElement:
    case Instruction::Select: {
      SmallVector<Value *, 2> Operands;
      getRelevantOperands(I, Operands);
      append_range(Worklist, Operands);
      break;
    }
    default:
      // TODO: Can handle more cases here:
      // 1. shufflevector
      // 2. sdiv, srem
      // 3. phi node(and loop handling)
      // ...
      return false;
    }
  }
  return true;
}

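// Compute the minimum bit-width in which the whole dag can be evaluated: the
// demanded ("valid") bit-width is propagated from the trunc towards the
// leaves, and on the way back each node's MinBitWidth is raised to the
// maximum over its operands' values.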
unsigned TruncInstCombine::getMinBitWidth() {
  SmallVector<Value *, 8> Worklist;
  SmallVector<Instruction *, 8> Stack;

  Value *Src = CurrentTruncInst->getOperand(0);
  Type *DstTy = CurrentTruncInst->getType();
  unsigned TruncBitWidth = DstTy->getScalarSizeInBits();
  unsigned OrigBitWidth =
      CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();

  if (isa<Constant>(Src))
    return TruncBitWidth;

  Worklist.push_back(Src);
  InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth;

  while (!Worklist.empty()) {
    Value *Curr = Worklist.back();

    if (isa<Constant>(Curr)) {
      Worklist.pop_back();
      continue;
    }

    // Otherwise, it must be an instruction.
    auto *I = cast<Instruction>(Curr);

    auto &Info = InstInfoMap[I];

    SmallVector<Value *, 2> Operands;
    getRelevantOperands(I, Operands);

    if (!Stack.empty() && Stack.back() == I) {
      // Already handled all instruction operands, can remove it from both the
      // Worklist and the Stack, and update MinBitWidth.
      Worklist.pop_back();
      Stack.pop_back();
      for (auto *Operand : Operands)
        if (auto *IOp = dyn_cast<Instruction>(Operand))
          Info.MinBitWidth =
              std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth);
      continue;
    }

    // Add the instruction to the stack before starting to handle its operands.
    Stack.push_back(I);
    unsigned ValidBitWidth = Info.ValidBitWidth;

    // Update minimum bit-width before handling its operands. This is required
    // when the instruction is part of a loop.
    Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth);

    for (auto *Operand : Operands)
      if (auto *IOp = dyn_cast<Instruction>(Operand)) {
        // If we already calculated the minimum bit-width for this valid
        // bit-width, or for a smaller valid bit-width, then just keep the
        // answer we already calculated.
        unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;
        if (IOpBitwidth >= ValidBitWidth)
          continue;
        InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;
        Worklist.push_back(IOp);
      }
  }
  unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth;
  assert(MinBitWidth >= TruncBitWidth);

  if (MinBitWidth > TruncBitWidth) {
    // In this case, reducing an expression with a vector type might generate a
    // new vector type, which is not preferable as it might result in generating
    // sub-optimal code.
    if (DstTy->isVectorTy())
      return OrigBitWidth;
    // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth).
    Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth);
    // Update the minimum bit-width with the new destination type's bit-width if
    // such a type was found; otherwise, use the original bit-width.
    MinBitWidth = Ty ? Ty->getScalarSizeInBits() : OrigBitWidth;
  } else { // MinBitWidth == TruncBitWidth
    // In this case the expression can be evaluated with the trunc instruction
    // destination type, and the trunc instruction can be omitted. However, we
    // should not perform the evaluation if the original type is a legal scalar
    // type and the target type is illegal.
    bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth);
    bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth);
    if (!DstTy->isVectorTy() && FromLegal && !ToLegal)
      return OrigBitWidth;
  }
  return MinBitWidth;
}

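// Decide whether the expression dag feeding CurrentTruncInst can be profitably
// evaluated in a narrower type; return that scalar integer type, or nullptr if
// the reduction is not possible or not profitable.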
Type *TruncInstCombine::getBestTruncatedType() {
  if (!buildTruncExpressionDag())
    return nullptr;

  // We don't want to duplicate instructions, which isn't profitable. Thus, we
  // can't shrink something that has multiple users, unless all users are
  // post-dominated by the trunc instruction, i.e., were visited during the
  // expression evaluation.
  unsigned DesiredBitWidth = 0;
  for (auto Itr : InstInfoMap) {
    Instruction *I = Itr.first;
    if (I->hasOneUse())
      continue;
    bool IsExtInst = (isa<ZExtInst>(I) || isa<SExtInst>(I));
    for (auto *U : I->users())
      if (auto *UI = dyn_cast<Instruction>(U))
        if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) {
          if (!IsExtInst)
            return nullptr;
          // If this is an extension from the dest type, we can eliminate it,
          // even if it has multiple users. Thus, update the DesiredBitWidth and
          // validate that all extension instructions agree on the same
          // DesiredBitWidth.
          unsigned ExtInstBitWidth =
              I->getOperand(0)->getType()->getScalarSizeInBits();
          if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth)
            return nullptr;
          DesiredBitWidth = ExtInstBitWidth;
        }
  }

  unsigned OrigBitWidth =
      CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();

  // Initialize MinBitWidth for shift instructions with the minimum number
  // that is greater than the shift amount (i.e., shift amount + 1).
  // For `lshr` adjust MinBitWidth so that all potentially truncated
  // bits of the value-to-be-shifted are zeros.
  // For `ashr` adjust MinBitWidth so that all potentially truncated
  // bits of the value-to-be-shifted are sign bits (all zeros or ones),
  // and the first untruncated bit is also a sign bit.
  // Exit early if MinBitWidth is not less than original bitwidth.
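  // For example, given `%r = lshr i32 %x, 2` where the upper 24 bits of %x are
  // known to be zero, the shift amount requires 3 bits and the shifted value
  // has 8 active bits, so MinBitWidth becomes 8 (an illustrative case).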
  for (auto &Itr : InstInfoMap) {
    Instruction *I = Itr.first;
    if (I->isShift()) {
      KnownBits KnownRHS = computeKnownBits(I->getOperand(1));
      unsigned MinBitWidth = KnownRHS.getMaxValue()
                                 .uadd_sat(APInt(OrigBitWidth, 1))
                                 .getLimitedValue(OrigBitWidth);
      if (MinBitWidth == OrigBitWidth)
        return nullptr;
      if (I->getOpcode() == Instruction::LShr) {
        KnownBits KnownLHS = computeKnownBits(I->getOperand(0));
        MinBitWidth =
            std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits());
      }
      if (I->getOpcode() == Instruction::AShr) {
        unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0));
        MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1);
      }
      if (MinBitWidth >= OrigBitWidth)
        return nullptr;
      Itr.second.MinBitWidth = MinBitWidth;
    }
    if (I->getOpcode() == Instruction::UDiv ||
        I->getOpcode() == Instruction::URem) {
      unsigned MinBitWidth = 0;
      for (const auto &Op : I->operands()) {
        KnownBits Known = computeKnownBits(Op);
        MinBitWidth =
            std::max(Known.getMaxValue().getActiveBits(), MinBitWidth);
        if (MinBitWidth >= OrigBitWidth)
          return nullptr;
      }
      Itr.second.MinBitWidth = MinBitWidth;
    }
  }

  // Calculate the minimum bit-width allowed for shrinking the currently
  // visited truncate's operand.
  unsigned MinBitWidth = getMinBitWidth();

  // Check that we can shrink to a smaller bit-width than the original one, and
  // that it matches the DesiredBitWidth, if such exists.
  if (MinBitWidth >= OrigBitWidth ||
      (DesiredBitWidth && DesiredBitWidth != MinBitWidth))
    return nullptr;

  return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth);
}

/// Given a reduced scalar type \p Ty and a value \p V, return the reduced type
/// for \p V: if \p V has a vector type, return the vector version of \p Ty;
/// otherwise, return \p Ty.
static Type *getReducedType(Value *V, Type *Ty) {
  assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type");
  if (auto *VTy = dyn_cast<VectorType>(V->getType()))
    return VectorType::get(Ty, VTy->getElementCount());
  return Ty;
}

Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) {
  Type *Ty = getReducedType(V, SclTy);
  if (auto *C = dyn_cast<Constant>(V)) {
    C = ConstantExpr::getIntegerCast(C, Ty, false);
    // If we got a constantexpr back, try to simplify it with DL info.
    return ConstantFoldConstant(C, DL, &TLI);
  }

  auto *I = cast<Instruction>(V);
  Info Entry = InstInfoMap.lookup(I);
  assert(Entry.NewValue);
  return Entry.NewValue;
}

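// Evaluate the whole expression dag in the reduced scalar type SclTy. The dag
// nodes were recorded in post-order (operands before their users), so a single
// forward pass can rebuild each instruction from already-reduced operands;
// finally, the original trunc is replaced and the old dag is erased.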
void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
  NumInstrsReduced += InstInfoMap.size();
  for (auto &Itr : InstInfoMap) { // Forward
    Instruction *I = Itr.first;
    TruncInstCombine::Info &NodeInfo = Itr.second;

    assert(!NodeInfo.NewValue && "Instruction has been evaluated");

    IRBuilder<> Builder(I);
    Value *Res = nullptr;
    unsigned Opc = I->getOpcode();
    switch (Opc) {
    case Instruction::Trunc:
    case Instruction::ZExt:
    case Instruction::SExt: {
      Type *Ty = getReducedType(I, SclTy);
      // If the source type of the cast is the type we're trying for then we can
      // just return the source.  There's no need to insert it because it is not
      // new.
      if (I->getOperand(0)->getType() == Ty) {
        assert(!isa<TruncInst>(I) && "Cannot reach here with TruncInst");
        NodeInfo.NewValue = I->getOperand(0);
        continue;
      }
      // Otherwise, must be the same type of cast, so just reinsert a new one.
      // This also handles the case of zext(trunc(x)) -> zext(x).
      Res = Builder.CreateIntCast(I->getOperand(0), Ty,
                                  Opc == Instruction::SExt);

      // Update Worklist entries with new value if needed.
      // There are three possible changes to the Worklist:
      // 1. Update Old-TruncInst -> New-TruncInst.
      // 2. Remove Old-TruncInst (if New node is not TruncInst).
      // 3. Add New-TruncInst (if Old node was not TruncInst).
      auto *Entry = find(Worklist, I);
      if (Entry != Worklist.end()) {
        if (auto *NewCI = dyn_cast<TruncInst>(Res))
          *Entry = NewCI;
        else
          Worklist.erase(Entry);
      } else if (auto *NewCI = dyn_cast<TruncInst>(Res))
          Worklist.push_back(NewCI);
      break;
    }
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::Shl:
    case Instruction::LShr:
    case Instruction::AShr:
    case Instruction::UDiv:
    case Instruction::URem: {
      Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
      Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
      Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
      // Preserve `exact` flag since truncation doesn't change exactness
      if (auto *PEO = dyn_cast<PossiblyExactOperator>(I))
        if (auto *ResI = dyn_cast<Instruction>(Res))
          ResI->setIsExact(PEO->isExact());
      break;
    }
    case Instruction::ExtractElement: {
      Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
      Value *Idx = I->getOperand(1);
      Res = Builder.CreateExtractElement(Vec, Idx);
      break;
    }
    case Instruction::InsertElement: {
      Value *Vec = getReducedOperand(I->getOperand(0), SclTy);
      Value *NewElt = getReducedOperand(I->getOperand(1), SclTy);
      Value *Idx = I->getOperand(2);
      Res = Builder.CreateInsertElement(Vec, NewElt, Idx);
      break;
    }
    case Instruction::Select: {
      Value *Op0 = I->getOperand(0);
      Value *LHS = getReducedOperand(I->getOperand(1), SclTy);
      Value *RHS = getReducedOperand(I->getOperand(2), SclTy);
      Res = Builder.CreateSelect(Op0, LHS, RHS);
      break;
    }
    default:
      llvm_unreachable("Unhandled instruction");
    }

    NodeInfo.NewValue = Res;
    if (auto *ResI = dyn_cast<Instruction>(Res))
      ResI->takeName(I);
  }

  Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy);
  Type *DstTy = CurrentTruncInst->getType();
  if (Res->getType() != DstTy) {
    IRBuilder<> Builder(CurrentTruncInst);
    Res = Builder.CreateIntCast(Res, DstTy, false);
    if (auto *ResI = dyn_cast<Instruction>(Res))
      ResI->takeName(CurrentTruncInst);
  }
  CurrentTruncInst->replaceAllUsesWith(Res);

  // Erase the old expression dag, which was replaced by the reduced expression
  // dag. We iterate backward, which means we visit an instruction before we
  // visit any of its operands; this way, when we get to an operand, we have
  // already removed the instructions (from the expression dag) that use it.
  CurrentTruncInst->eraseFromParent();
  for (auto &I : llvm::reverse(InstInfoMap)) {
    // We still need to check that the instruction has no users before we erase
    // it, because a {SExt, ZExt}Inst instruction might have other users that
    // were not reduced; in such a case, we need to keep that instruction.
    if (I.first->use_empty())
      I.first->eraseFromParent();
  }
}

bool TruncInstCombine::run(Function &F) {
  bool MadeIRChange = false;

  // Collect all TruncInsts in the function into the Worklist for evaluation.
  for (auto &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    for (auto &I : BB)
      if (auto *CI = dyn_cast<TruncInst>(&I))
        Worklist.push_back(CI);
  }

  // Process all TruncInsts in the Worklist; for each instruction:
  //   1. Check if it dominates an eligible expression dag to be reduced.
  //   2. Create a reduced expression dag and replace the old one with it.
  while (!Worklist.empty()) {
    CurrentTruncInst = Worklist.pop_back_val();

    if (Type *NewDstSclTy = getBestTruncatedType()) {
      LLVM_DEBUG(
          dbgs() << "ICE: TruncInstCombine reducing type of expression dag "
                    "dominated by: "
                 << CurrentTruncInst << '\n');
      ReduceExpressionDag(NewDstSclTy);
      ++NumDAGsReduced;
      MadeIRChange = true;
    }
  }

  return MadeIRChange;
}