author | shadchin <shadchin@yandex-team.ru> | 2022-02-10 16:44:30 +0300 |
---|---|---|
committer | Daniil Cherednik <dcherednik@yandex-team.ru> | 2022-02-10 16:44:30 +0300 |
commit | 2598ef1d0aee359b4b6d5fdd1758916d5907d04f (patch) | |
tree | 012bb94d777798f1f56ac1cec429509766d05181 /contrib/libs/llvm12/lib/Analysis/ScalarEvolution.cpp | |
parent | 6751af0b0c1b952fede40b19b71da8025b5d8bcf (diff) | |
download | ydb-2598ef1d0aee359b4b6d5fdd1758916d5907d04f.tar.gz | |
Restoring authorship annotation for <shadchin@yandex-team.ru>. Commit 1 of 2.
Diffstat (limited to 'contrib/libs/llvm12/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | contrib/libs/llvm12/lib/Analysis/ScalarEvolution.cpp | 2890 |
1 file changed, 1445 insertions, 1445 deletions
diff --git a/contrib/libs/llvm12/lib/Analysis/ScalarEvolution.cpp b/contrib/libs/llvm12/lib/Analysis/ScalarEvolution.cpp index 1a9ae68573..5ff63868c9 100644 --- a/contrib/libs/llvm12/lib/Analysis/ScalarEvolution.cpp +++ b/contrib/libs/llvm12/lib/Analysis/ScalarEvolution.cpp @@ -135,7 +135,7 @@ #include <vector> using namespace llvm; -using namespace PatternMatch; +using namespace PatternMatch; #define DEBUG_TYPE "scalar-evolution" @@ -227,11 +227,11 @@ ClassifyExpressions("scalar-evolution-classify-expressions", cl::Hidden, cl::init(true), cl::desc("When printing analysis, include information on every instruction")); -static cl::opt<bool> UseExpensiveRangeSharpening( - "scalar-evolution-use-expensive-range-sharpening", cl::Hidden, - cl::init(false), - cl::desc("Use more powerful methods of sharpening expression ranges. May " - "be costly in terms of compile time")); +static cl::opt<bool> UseExpensiveRangeSharpening( + "scalar-evolution-use-expensive-range-sharpening", cl::Hidden, + cl::init(false), + cl::desc("Use more powerful methods of sharpening expression ranges. May " + "be costly in terms of compile time")); //===----------------------------------------------------------------------===// // SCEV class definitions @@ -249,17 +249,17 @@ LLVM_DUMP_METHOD void SCEV::dump() const { #endif void SCEV::print(raw_ostream &OS) const { - switch (getSCEVType()) { + switch (getSCEVType()) { case scConstant: cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false); return; - case scPtrToInt: { - const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this); - const SCEV *Op = PtrToInt->getOperand(); - OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to " - << *PtrToInt->getType() << ")"; - return; - } + case scPtrToInt: { + const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this); + const SCEV *Op = PtrToInt->getOperand(); + OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to " + << *PtrToInt->getType() << ")"; + return; + } case scTruncate: { const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this); const SCEV *Op = Trunc->getOperand(); @@ -317,8 +317,8 @@ void SCEV::print(raw_ostream &OS) const { case scSMinExpr: OpStr = " smin "; break; - default: - llvm_unreachable("There are no other nary expression types."); + default: + llvm_unreachable("There are no other nary expression types."); } OS << "("; for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end(); @@ -335,10 +335,10 @@ void SCEV::print(raw_ostream &OS) const { OS << "<nuw>"; if (NAry->hasNoSignedWrap()) OS << "<nsw>"; - break; - default: - // Nothing to print for other nary expressions. - break; + break; + default: + // Nothing to print for other nary expressions. 
+ break; } return; } @@ -380,10 +380,10 @@ void SCEV::print(raw_ostream &OS) const { } Type *SCEV::getType() const { - switch (getSCEVType()) { + switch (getSCEVType()) { case scConstant: return cast<SCEVConstant>(this)->getType(); - case scPtrToInt: + case scPtrToInt: case scTruncate: case scZeroExtend: case scSignExtend: @@ -465,42 +465,42 @@ ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) { return getConstant(ConstantInt::get(ITy, V, isSigned)); } -SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, - const SCEV *op, Type *ty) - : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) { - Operands[0] = op; -} - -SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, - Type *ITy) - : SCEVCastExpr(ID, scPtrToInt, Op, ITy) { - assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() && - "Must be a non-bit-width-changing pointer-to-integer cast!"); -} - -SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, - SCEVTypes SCEVTy, const SCEV *op, - Type *ty) - : SCEVCastExpr(ID, SCEVTy, op, ty) {} - -SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, - Type *ty) - : SCEVIntegralCastExpr(ID, scTruncate, op, ty) { - assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && +SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, + const SCEV *op, Type *ty) + : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) { + Operands[0] = op; +} + +SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, + Type *ITy) + : SCEVCastExpr(ID, scPtrToInt, Op, ITy) { + assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() && + "Must be a non-bit-width-changing pointer-to-integer cast!"); +} + +SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, + SCEVTypes SCEVTy, const SCEV *op, + Type *ty) + : SCEVCastExpr(ID, SCEVTy, op, ty) {} + +SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, + Type *ty) + : SCEVIntegralCastExpr(ID, scTruncate, op, ty) { + assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot truncate non-integer value!"); } SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty) - : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) { - assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && + : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) { + assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot zero extend non-integer value!"); } SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty) - : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) { - assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && + : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) { + assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() && "Cannot sign extend non-integer value!"); } @@ -699,7 +699,7 @@ static int CompareSCEVComplexity( return 0; // Primarily, sort the SCEVs by their getSCEVType(). 
- SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); + SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); if (LType != RType) return (int)LType - (int)RType; @@ -708,7 +708,7 @@ static int CompareSCEVComplexity( // Aside from the getSCEVType() ordering, the particular ordering // isn't very important except that it's beneficial to be consistent, // so that (a + b) and (b + a) don't end up as different expressions. - switch (LType) { + switch (LType) { case scUnknown: { const SCEVUnknown *LU = cast<SCEVUnknown>(LHS); const SCEVUnknown *RU = cast<SCEVUnknown>(RHS); @@ -810,7 +810,7 @@ static int CompareSCEVComplexity( return X; } - case scPtrToInt: + case scPtrToInt: case scTruncate: case scZeroExtend: case scSignExtend: { @@ -1034,115 +1034,115 @@ const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It, // SCEV Expression folder implementations //===----------------------------------------------------------------------===// -const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty, - unsigned Depth) { - assert(Ty->isIntegerTy() && "Target type must be an integer type!"); - assert(Depth <= 1 && "getPtrToIntExpr() should self-recurse at most once."); - - // We could be called with an integer-typed operands during SCEV rewrites. - // Since the operand is an integer already, just perform zext/trunc/self cast. - if (!Op->getType()->isPointerTy()) - return getTruncateOrZeroExtend(Op, Ty); - - // What would be an ID for such a SCEV cast expression? - FoldingSetNodeID ID; - ID.AddInteger(scPtrToInt); - ID.AddPointer(Op); - - void *IP = nullptr; - - // Is there already an expression for such a cast? - if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) - return getTruncateOrZeroExtend(S, Ty); - - // If not, is this expression something we can't reduce any further? - if (isa<SCEVUnknown>(Op)) { - // Create an explicit cast node. - // We can reuse the existing insert position since if we get here, - // we won't have made any changes which would invalidate it. - Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType()); - assert(getDataLayout().getTypeSizeInBits(getEffectiveSCEVType( - Op->getType())) == getDataLayout().getTypeSizeInBits(IntPtrTy) && - "We can only model ptrtoint if SCEV's effective (integer) type is " - "sufficiently wide to represent all possible pointer values."); - SCEV *S = new (SCEVAllocator) - SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy); - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); - return getTruncateOrZeroExtend(S, Ty); - } - - assert(Depth == 0 && - "getPtrToIntExpr() should not self-recurse for non-SCEVUnknown's."); - - // Otherwise, we've got some expression that is more complex than just a - // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an - // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown - // only, and the expressions must otherwise be integer-typed. - // So sink the cast down to the SCEVUnknown's. - - /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression, - /// which computes a pointer-typed value, and rewrites the whole expression - /// tree so that *all* the computations are done on integers, and the only - /// pointer-typed operands in the expression are SCEVUnknown. 
- class SCEVPtrToIntSinkingRewriter - : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> { - using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>; - - public: - SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {} - - static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) { - SCEVPtrToIntSinkingRewriter Rewriter(SE); - return Rewriter.visit(Scev); - } - - const SCEV *visit(const SCEV *S) { - Type *STy = S->getType(); - // If the expression is not pointer-typed, just keep it as-is. - if (!STy->isPointerTy()) - return S; - // Else, recursively sink the cast down into it. - return Base::visit(S); - } - - const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { - SmallVector<const SCEV *, 2> Operands; - bool Changed = false; - for (auto *Op : Expr->operands()) { - Operands.push_back(visit(Op)); - Changed |= Op != Operands.back(); - } - return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags()); - } - - const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { - SmallVector<const SCEV *, 2> Operands; - bool Changed = false; - for (auto *Op : Expr->operands()) { - Operands.push_back(visit(Op)); - Changed |= Op != Operands.back(); - } - return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags()); - } - - const SCEV *visitUnknown(const SCEVUnknown *Expr) { - Type *ExprPtrTy = Expr->getType(); - assert(ExprPtrTy->isPointerTy() && - "Should only reach pointer-typed SCEVUnknown's."); - Type *ExprIntPtrTy = SE.getDataLayout().getIntPtrType(ExprPtrTy); - return SE.getPtrToIntExpr(Expr, ExprIntPtrTy, /*Depth=*/1); - } - }; - - // And actually perform the cast sinking. - const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this); - assert(IntOp->getType()->isIntegerTy() && - "We must have succeeded in sinking the cast, " - "and ending up with an integer-typed expression!"); - return getTruncateOrZeroExtend(IntOp, Ty); -} - +const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty, + unsigned Depth) { + assert(Ty->isIntegerTy() && "Target type must be an integer type!"); + assert(Depth <= 1 && "getPtrToIntExpr() should self-recurse at most once."); + + // We could be called with an integer-typed operands during SCEV rewrites. + // Since the operand is an integer already, just perform zext/trunc/self cast. + if (!Op->getType()->isPointerTy()) + return getTruncateOrZeroExtend(Op, Ty); + + // What would be an ID for such a SCEV cast expression? + FoldingSetNodeID ID; + ID.AddInteger(scPtrToInt); + ID.AddPointer(Op); + + void *IP = nullptr; + + // Is there already an expression for such a cast? + if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) + return getTruncateOrZeroExtend(S, Ty); + + // If not, is this expression something we can't reduce any further? + if (isa<SCEVUnknown>(Op)) { + // Create an explicit cast node. + // We can reuse the existing insert position since if we get here, + // we won't have made any changes which would invalidate it. 
+ Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType()); + assert(getDataLayout().getTypeSizeInBits(getEffectiveSCEVType( + Op->getType())) == getDataLayout().getTypeSizeInBits(IntPtrTy) && + "We can only model ptrtoint if SCEV's effective (integer) type is " + "sufficiently wide to represent all possible pointer values."); + SCEV *S = new (SCEVAllocator) + SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy); + UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); + return getTruncateOrZeroExtend(S, Ty); + } + + assert(Depth == 0 && + "getPtrToIntExpr() should not self-recurse for non-SCEVUnknown's."); + + // Otherwise, we've got some expression that is more complex than just a + // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an + // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown + // only, and the expressions must otherwise be integer-typed. + // So sink the cast down to the SCEVUnknown's. + + /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression, + /// which computes a pointer-typed value, and rewrites the whole expression + /// tree so that *all* the computations are done on integers, and the only + /// pointer-typed operands in the expression are SCEVUnknown. + class SCEVPtrToIntSinkingRewriter + : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> { + using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>; + + public: + SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {} + + static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) { + SCEVPtrToIntSinkingRewriter Rewriter(SE); + return Rewriter.visit(Scev); + } + + const SCEV *visit(const SCEV *S) { + Type *STy = S->getType(); + // If the expression is not pointer-typed, just keep it as-is. + if (!STy->isPointerTy()) + return S; + // Else, recursively sink the cast down into it. + return Base::visit(S); + } + + const SCEV *visitAddExpr(const SCEVAddExpr *Expr) { + SmallVector<const SCEV *, 2> Operands; + bool Changed = false; + for (auto *Op : Expr->operands()) { + Operands.push_back(visit(Op)); + Changed |= Op != Operands.back(); + } + return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags()); + } + + const SCEV *visitMulExpr(const SCEVMulExpr *Expr) { + SmallVector<const SCEV *, 2> Operands; + bool Changed = false; + for (auto *Op : Expr->operands()) { + Operands.push_back(visit(Op)); + Changed |= Op != Operands.back(); + } + return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags()); + } + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + Type *ExprPtrTy = Expr->getType(); + assert(ExprPtrTy->isPointerTy() && + "Should only reach pointer-typed SCEVUnknown's."); + Type *ExprIntPtrTy = SE.getDataLayout().getIntPtrType(ExprPtrTy); + return SE.getPtrToIntExpr(Expr, ExprIntPtrTy, /*Depth=*/1); + } + }; + + // And actually perform the cast sinking. 
+ const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this); + assert(IntOp->getType()->isIntegerTy() && + "We must have succeeded in sinking the cast, " + "and ending up with an integer-typed expression!"); + return getTruncateOrZeroExtend(IntOp, Ty); +} + const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, unsigned Depth) { assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) && @@ -1194,8 +1194,8 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2; ++i) { const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1); - if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) && - isa<SCEVTruncateExpr>(S)) + if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) && + isa<SCEVTruncateExpr>(S)) numTruncs++; Operands.push_back(S); } @@ -1222,11 +1222,11 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty, return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap); } - // Return zero if truncating to known zeros. - uint32_t MinTrailingZeros = GetMinTrailingZeros(Op); - if (MinTrailingZeros >= getTypeSizeInBits(Ty)) - return getZero(Ty); - + // Return zero if truncating to known zeros. + uint32_t MinTrailingZeros = GetMinTrailingZeros(Op); + if (MinTrailingZeros >= getTypeSizeInBits(Ty)) + return getZero(Ty); + // The cast wasn't folded; create an explicit cast node. We can reuse // the existing insert position since if we get here, we won't have // made any changes which would invalidate it. @@ -1387,7 +1387,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty, // If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW // or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then // `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact. - SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType); + SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType); } return PreStart; } @@ -1591,7 +1591,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { if (!AR->hasNoUnsignedWrap()) { auto NewFlags = proveNoWrapViaConstantRanges(AR); - setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags); + setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags); } // If we have special knowledge that this addrec won't overflow, @@ -1611,7 +1611,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { // that value once it has finished. const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L); if (!isa<SCEVCouldNotCompute>(MaxBECount)) { - // Manually compute the final value for AR, checking for overflow. + // Manually compute the final value for AR, checking for overflow. // Check whether the backedge-taken count can be losslessly casted to // the addrec's type. The count is always unsigned. @@ -1639,7 +1639,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { SCEV::FlagAnyWrap, Depth + 1); if (ZAdd == OperandExtendedAdd) { // Cache knowledge of AR NUW, which is propagated to this AddRec. - setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW); + setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW); // Return the expression with the addrec on the outside. 
return getAddRecExpr( getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, @@ -1658,7 +1658,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { if (ZAdd == OperandExtendedAdd) { // Cache knowledge of AR NW, which is propagated to this AddRec. // Negative step causes unsigned wrap, but it still can't self-wrap. - setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW); + setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW); // Return the expression with the addrec on the outside. return getAddRecExpr( getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, @@ -1678,24 +1678,24 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { // doing extra work that may not pay off. if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards || !AC.assumptions().empty()) { - - auto NewFlags = proveNoUnsignedWrapViaInduction(AR); - setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags); - if (AR->hasNoUnsignedWrap()) { - // Same as nuw case above - duplicated here to avoid a compile time - // issue. It's not clear that the order of checks does matter, but - // it's one of two issue possible causes for a change which was - // reverted. Be conservative for the moment. - return getAddRecExpr( + + auto NewFlags = proveNoUnsignedWrapViaInduction(AR); + setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags); + if (AR->hasNoUnsignedWrap()) { + // Same as nuw case above - duplicated here to avoid a compile time + // issue. It's not clear that the order of checks does matter, but + // it's one of two issue possible causes for a change which was + // reverted. Be conservative for the moment. + return getAddRecExpr( getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1), getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); - } - - // For a negative step, we can extend the operands iff doing so only - // traverses values in the range zext([0,UINT_MAX]). - if (isKnownNegative(Step)) { + } + + // For a negative step, we can extend the operands iff doing so only + // traverses values in the range zext([0,UINT_MAX]). + if (isKnownNegative(Step)) { const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) - getSignedRangeMin(Step)); if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) || @@ -1703,7 +1703,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { // Cache knowledge of AR NW, which is propagated to this // AddRec. Negative step causes unsigned wrap, but it // still can't self-wrap. - setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW); + setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW); // Return the expression with the addrec on the outside. 
return getAddRecExpr( getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, @@ -1732,7 +1732,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { } if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) { - setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW); + setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW); return getAddRecExpr( getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1), getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); @@ -1931,7 +1931,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { if (!AR->hasNoSignedWrap()) { auto NewFlags = proveNoWrapViaConstantRanges(AR); - setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags); + setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags); } // If we have special knowledge that this addrec won't overflow, @@ -1980,7 +1980,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { SCEV::FlagAnyWrap, Depth + 1); if (SAdd == OperandExtendedAdd) { // Cache knowledge of AR NSW, which is propagated to this AddRec. - setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW); + setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW); // Return the expression with the addrec on the outside. return getAddRecExpr( getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, @@ -2005,7 +2005,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { // Thus (AR is not NW => SAdd != OperandExtendedAdd) <=> // (SAdd == OperandExtendedAdd => AR is NW) - setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW); + setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW); // Return the expression with the addrec on the outside. return getAddRecExpr( @@ -2017,16 +2017,16 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { } } - auto NewFlags = proveNoSignedWrapViaInduction(AR); - setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags); - if (AR->hasNoSignedWrap()) { - // Same as nsw case above - duplicated here to avoid a compile time - // issue. It's not clear that the order of checks does matter, but - // it's one of two issue possible causes for a change which was - // reverted. Be conservative for the moment. - return getAddRecExpr( - getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), - getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); + auto NewFlags = proveNoSignedWrapViaInduction(AR); + setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags); + if (AR->hasNoSignedWrap()) { + // Same as nsw case above - duplicated here to avoid a compile time + // issue. It's not clear that the order of checks does matter, but + // it's one of two issue possible causes for a change which was + // reverted. Be conservative for the moment. 
+ return getAddRecExpr( + getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), + getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); } // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw> @@ -2047,7 +2047,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { } if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) { - setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW); + setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW); return getAddRecExpr( getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1), getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags()); @@ -2177,7 +2177,7 @@ CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M, } else { // A multiplication of a constant with some other value. Update // the map. - SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands())); + SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands())); const SCEV *Key = SE.getMulExpr(MulOps); auto Pair = M.insert({Key, NewScale}); if (Pair.second) { @@ -2281,9 +2281,9 @@ bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) { /// Get a canonical add expression, or something simpler if possible. const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, - SCEV::NoWrapFlags OrigFlags, + SCEV::NoWrapFlags OrigFlags, unsigned Depth) { - assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && + assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty add!"); if (Ops.size() == 1) return Ops[0]; @@ -2319,20 +2319,20 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, if (Ops.size() == 1) return Ops[0]; } - // Delay expensive flag strengthening until necessary. - auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) { - return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags); - }; - + // Delay expensive flag strengthening until necessary. + auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) { + return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags); + }; + // Limit recursion calls depth. if (Depth > MaxArithDepth || hasHugeExpression(Ops)) - return getOrCreateAddExpr(Ops, ComputeFlags(Ops)); + return getOrCreateAddExpr(Ops, ComputeFlags(Ops)); if (SCEV *S = std::get<0>(findExistingSCEVInCache(scAddExpr, Ops))) { - // Don't strengthen flags if we have no new information. - SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S); - if (Add->getNoWrapFlags(OrigFlags) != OrigFlags) - Add->setNoWrapFlags(ComputeFlags(Ops)); + // Don't strengthen flags if we have no new information. + SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S); + if (Add->getNoWrapFlags(OrigFlags) != OrigFlags) + Add->setNoWrapFlags(ComputeFlags(Ops)); return S; } @@ -2358,7 +2358,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, FoundMatch = true; } if (FoundMatch) - return getAddExpr(Ops, OrigFlags, Depth + 1); + return getAddExpr(Ops, OrigFlags, Depth + 1); // Check for truncates. If all the operands are truncated from the same // type, see if factoring out the truncate would permit the result to be @@ -2593,16 +2593,16 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // If we found some loop invariants, fold them into the recurrence. if (!LIOps.empty()) { - // Compute nowrap flags for the addition of the loop-invariant ops and - // the addrec. 
Temporarily push it as an operand for that purpose. - LIOps.push_back(AddRec); - SCEV::NoWrapFlags Flags = ComputeFlags(LIOps); - LIOps.pop_back(); - + // Compute nowrap flags for the addition of the loop-invariant ops and + // the addrec. Temporarily push it as an operand for that purpose. + LIOps.push_back(AddRec); + SCEV::NoWrapFlags Flags = ComputeFlags(LIOps); + LIOps.pop_back(); + // NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step} LIOps.push_back(AddRec->getStart()); - SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands()); + SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands()); // This follows from the fact that the no-wrap flags on the outer add // expression are applicable on the 0th iteration, when the add recurrence // will be equal to its start value. @@ -2640,7 +2640,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, "AddRecExprs are not sorted in reverse dominance order?"); if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) { // Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L> - SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands()); + SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands()); for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); ++OtherIdx) { const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]); @@ -2671,7 +2671,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // Okay, it looks like we really DO need an add expr. Check to see if we // already have one, otherwise create a new one. - return getOrCreateAddExpr(Ops, ComputeFlags(Ops)); + return getOrCreateAddExpr(Ops, ComputeFlags(Ops)); } const SCEV * @@ -2715,7 +2715,7 @@ ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops, UniqueSCEVs.InsertNode(S, IP); addToLoopUseLists(S); } - setNoWrapFlags(S, Flags); + setNoWrapFlags(S, Flags); return S; } @@ -2797,9 +2797,9 @@ static bool containsConstantInAddMulChain(const SCEV *StartExpr) { /// Get a canonical multiply expression, or something simpler if possible. const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, - SCEV::NoWrapFlags OrigFlags, + SCEV::NoWrapFlags OrigFlags, unsigned Depth) { - assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) && + assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) && "only nuw or nsw allowed"); assert(!Ops.empty() && "Cannot get empty mul!"); if (Ops.size() == 1) return Ops[0]; @@ -2813,52 +2813,52 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, &LI, DT); - // If there are any constants, fold them together. - unsigned Idx = 0; - if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { - ++Idx; - assert(Idx < Ops.size()); - while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { - // We found two constants, fold them together! - Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt()); - if (Ops.size() == 2) return Ops[0]; - Ops.erase(Ops.begin()+1); // Erase the folded element - LHSC = cast<SCEVConstant>(Ops[0]); - } - - // If we have a multiply of zero, it will always be zero. - if (LHSC->getValue()->isZero()) - return LHSC; - - // If we are left with a constant one being multiplied, strip it off. - if (LHSC->getValue()->isOne()) { - Ops.erase(Ops.begin()); - --Idx; - } - - if (Ops.size() == 1) - return Ops[0]; - } - - // Delay expensive flag strengthening until necessary. 
- auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) { - return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags); - }; - - // Limit recursion calls depth. - if (Depth > MaxArithDepth || hasHugeExpression(Ops)) - return getOrCreateMulExpr(Ops, ComputeFlags(Ops)); - + // If there are any constants, fold them together. + unsigned Idx = 0; + if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { + ++Idx; + assert(Idx < Ops.size()); + while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { + // We found two constants, fold them together! + Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt()); + if (Ops.size() == 2) return Ops[0]; + Ops.erase(Ops.begin()+1); // Erase the folded element + LHSC = cast<SCEVConstant>(Ops[0]); + } + + // If we have a multiply of zero, it will always be zero. + if (LHSC->getValue()->isZero()) + return LHSC; + + // If we are left with a constant one being multiplied, strip it off. + if (LHSC->getValue()->isOne()) { + Ops.erase(Ops.begin()); + --Idx; + } + + if (Ops.size() == 1) + return Ops[0]; + } + + // Delay expensive flag strengthening until necessary. + auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) { + return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags); + }; + + // Limit recursion calls depth. + if (Depth > MaxArithDepth || hasHugeExpression(Ops)) + return getOrCreateMulExpr(Ops, ComputeFlags(Ops)); + if (SCEV *S = std::get<0>(findExistingSCEVInCache(scMulExpr, Ops))) { - // Don't strengthen flags if we have no new information. - SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S); - if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags) - Mul->setNoWrapFlags(ComputeFlags(Ops)); + // Don't strengthen flags if we have no new information. + SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S); + if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags) + Mul->setNoWrapFlags(ComputeFlags(Ops)); return S; } if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) { - if (Ops.size() == 2) { + if (Ops.size() == 2) { // C1*(C2+V) -> C1*C2 + C1*V if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) // If any of Add's ops are Adds or Muls with a constant, apply this @@ -2874,9 +2874,9 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, SCEV::FlagAnyWrap, Depth + 1), SCEV::FlagAnyWrap, Depth + 1); - if (Ops[0]->isAllOnesValue()) { - // If we have a mul by -1 of an add, try distributing the -1 among the - // add operands. + if (Ops[0]->isAllOnesValue()) { + // If we have a mul by -1 of an add, try distributing the -1 among the + // add operands. if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) { SmallVector<const SCEV *, 4> NewOps; bool AnyFolded = false; @@ -2961,9 +2961,9 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // // No self-wrap cannot be guaranteed after changing the step size, but // will be inferred if either NUW or NSW is true. - SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec}); - const SCEV *NewRec = getAddRecExpr( - NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags)); + SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec}); + const SCEV *NewRec = getAddRecExpr( + NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags)); // If all of the other operands were loop invariant, we are done. if (Ops.size() == 1) return NewRec; @@ -3056,7 +3056,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // Okay, it looks like we really DO need an mul expr. 
Check to see if we // already have one, otherwise create a new one. - return getOrCreateMulExpr(Ops, ComputeFlags(Ops)); + return getOrCreateMulExpr(Ops, ComputeFlags(Ops)); } /// Represents an unsigned remainder expression based on unsigned division. @@ -3180,7 +3180,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, const SCEV *Op = M->getOperand(i); const SCEV *Div = getUDivExpr(Op, RHSC); if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) { - Operands = SmallVector<const SCEV *, 4>(M->operands()); + Operands = SmallVector<const SCEV *, 4>(M->operands()); Operands[i] = Div; return getMulExpr(Operands); } @@ -3274,7 +3274,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, // first element of the mulexpr. if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) { if (LHSCst == RHSCst) { - SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands())); + SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands())); return getMulExpr(Operands); } @@ -3364,7 +3364,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands, ? (L->getLoopDepth() < NestedLoop->getLoopDepth()) : (!NestedLoop->contains(L) && DT.dominates(L->getHeader(), NestedLoop->getHeader()))) { - SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands()); + SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands()); Operands[0] = NestedAR->getStart(); // AddRecs require their operands be loop-invariant with respect to their // loops. Don't perform this transformation if it would break this @@ -3417,12 +3417,12 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP, // flow and the no-overflow bits may not be valid for the expression in any // context. This can be fixed similarly to how these flags are handled for // adds. - SCEV::NoWrapFlags OffsetWrap = - GEP->isInBounds() ? SCEV::FlagNSW : SCEV::FlagAnyWrap; + SCEV::NoWrapFlags OffsetWrap = + GEP->isInBounds() ? SCEV::FlagNSW : SCEV::FlagAnyWrap; Type *CurTy = GEP->getType(); bool FirstIter = true; - SmallVector<const SCEV *, 4> Offsets; + SmallVector<const SCEV *, 4> Offsets; for (const SCEV *IndexExpr : IndexExprs) { // Compute the (potentially symbolic) offset in bytes for this index. if (StructType *STy = dyn_cast<StructType>(CurTy)) { @@ -3430,7 +3430,7 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP, ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue(); unsigned FieldNo = Index->getZExtValue(); const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo); - Offsets.push_back(FieldOffset); + Offsets.push_back(FieldOffset); // Update CurTy to the type of the field at Index. CurTy = STy->getTypeAtIndex(Index); @@ -3450,27 +3450,27 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP, IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy); // Multiply the index by the element size to compute the element offset. - const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap); - Offsets.push_back(LocalOffset); + const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap); + Offsets.push_back(LocalOffset); } } - // Handle degenerate case of GEP without offsets. - if (Offsets.empty()) - return BaseExpr; - - // Add the offsets together, assuming nsw if inbounds. - const SCEV *Offset = getAddExpr(Offsets, OffsetWrap); - // Add the base address and the offset. We cannot use the nsw flag, as the - // base address is unsigned. However, if we know that the offset is - // non-negative, we can use nuw. 
- SCEV::NoWrapFlags BaseWrap = GEP->isInBounds() && isKnownNonNegative(Offset) - ? SCEV::FlagNUW : SCEV::FlagAnyWrap; - return getAddExpr(BaseExpr, Offset, BaseWrap); + // Handle degenerate case of GEP without offsets. + if (Offsets.empty()) + return BaseExpr; + + // Add the offsets together, assuming nsw if inbounds. + const SCEV *Offset = getAddExpr(Offsets, OffsetWrap); + // Add the base address and the offset. We cannot use the nsw flag, as the + // base address is unsigned. However, if we know that the offset is + // non-negative, we can use nuw. + SCEV::NoWrapFlags BaseWrap = GEP->isInBounds() && isKnownNonNegative(Offset) + ? SCEV::FlagNUW : SCEV::FlagAnyWrap; + return getAddExpr(BaseExpr, Offset, BaseWrap); } std::tuple<SCEV *, FoldingSetNodeID, void *> -ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType, +ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType, ArrayRef<const SCEV *> Ops) { FoldingSetNodeID ID; void *IP = nullptr; @@ -3481,17 +3481,17 @@ ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType, UniqueSCEVs.FindNodeOrInsertPos(ID, IP), std::move(ID), IP); } -const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) { - SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap; - return getSMaxExpr(Op, getNegativeSCEV(Op, Flags)); -} - -const SCEV *ScalarEvolution::getSignumExpr(const SCEV *Op) { - Type *Ty = Op->getType(); - return getSMinExpr(getSMaxExpr(Op, getMinusOne(Ty)), getOne(Ty)); -} - -const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind, +const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) { + SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap; + return getSMaxExpr(Op, getNegativeSCEV(Op, Flags)); +} + +const SCEV *ScalarEvolution::getSignumExpr(const SCEV *Op) { + Type *Ty = Op->getType(); + return getSMinExpr(getSMaxExpr(Op, getMinusOne(Ty)), getOne(Ty)); +} + +const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind, SmallVectorImpl<const SCEV *> &Ops) { assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!"); if (Ops.size() == 1) return Ops[0]; @@ -3615,8 +3615,8 @@ const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind, return ExistingSCEV; const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); std::uninitialized_copy(Ops.begin(), Ops.end(), O); - SCEV *S = new (SCEVAllocator) - SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size()); + SCEV *S = new (SCEVAllocator) + SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); addToLoopUseLists(S); @@ -3661,42 +3661,42 @@ const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops) { return getMinMaxExpr(scUMinExpr, Ops); } -const SCEV * -ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy, - ScalableVectorType *ScalableTy) { - Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo()); - Constant *One = ConstantInt::get(IntTy, 1); - Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One); - // Note that the expression we created is the final expression, we don't - // want to simplify it any further Also, if we call a normal getSCEV(), - // we'll end up in an endless recursion. So just create an SCEVUnknown. 
- return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy)); -} - +const SCEV * +ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy, + ScalableVectorType *ScalableTy) { + Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo()); + Constant *One = ConstantInt::get(IntTy, 1); + Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One); + // Note that the expression we created is the final expression, we don't + // want to simplify it any further Also, if we call a normal getSCEV(), + // we'll end up in an endless recursion. So just create an SCEVUnknown. + return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy)); +} + const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) { - if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy)) - return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy); - // We can bypass creating a target-independent constant expression and then - // folding it back into a ConstantInt. This is just a compile-time - // optimization. + if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy)) + return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy); + // We can bypass creating a target-independent constant expression and then + // folding it back into a ConstantInt. This is just a compile-time + // optimization. return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy)); } -const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) { - if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy)) - return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy); - // We can bypass creating a target-independent constant expression and then - // folding it back into a ConstantInt. This is just a compile-time - // optimization. - return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy)); -} - +const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) { + if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy)) + return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy); + // We can bypass creating a target-independent constant expression and then + // folding it back into a ConstantInt. This is just a compile-time + // optimization. + return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy)); +} + const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy, StructType *STy, unsigned FieldNo) { - // We can bypass creating a target-independent constant expression and then - // folding it back into a ConstantInt. This is just a compile-time - // optimization. + // We can bypass creating a target-independent constant expression and then + // folding it back into a ConstantInt. This is just a compile-time + // optimization. 
return getConstant( IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo)); } @@ -3920,7 +3920,7 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, Type *Ty = V->getType(); Ty = getEffectiveSCEVType(Ty); - return getMulExpr(V, getMinusOne(Ty), Flags); + return getMulExpr(V, getMinusOne(Ty), Flags); } /// If Expr computes ~A, return A else return nullptr @@ -3954,8 +3954,8 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { return (const SCEV *)nullptr; MatchedOperands.push_back(Matched); } - return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()), - MatchedOperands); + return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()), + MatchedOperands); }; if (const SCEV *Replaced = MatchMinMaxNegation(MME)) return Replaced; @@ -3963,7 +3963,7 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) { Type *Ty = V->getType(); Ty = getEffectiveSCEVType(Ty); - return getMinusSCEV(getMinusOne(Ty), V); + return getMinusSCEV(getMinusOne(Ty), V); } const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, @@ -4110,7 +4110,7 @@ const SCEV *ScalarEvolution::getUMinFromMismatchedTypes( MaxType = getWiderType(MaxType, S->getType()); else MaxType = S->getType(); - assert(MaxType && "Failed to find maximum type!"); + assert(MaxType && "Failed to find maximum type!"); // Extend all ops to max type. SmallVector<const SCEV *, 2> PromotedOps; @@ -4127,7 +4127,7 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { return V; while (true) { - if (const SCEVIntegralCastExpr *Cast = dyn_cast<SCEVIntegralCastExpr>(V)) { + if (const SCEVIntegralCastExpr *Cast = dyn_cast<SCEVIntegralCastExpr>(V)) { V = Cast->getOperand(); } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) { const SCEV *PtrOp = nullptr; @@ -4430,107 +4430,107 @@ ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) { return Result; } -SCEV::NoWrapFlags -ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) { - SCEV::NoWrapFlags Result = AR->getNoWrapFlags(); - - if (AR->hasNoSignedWrap()) - return Result; - - if (!AR->isAffine()) - return Result; - - const SCEV *Step = AR->getStepRecurrence(*this); - const Loop *L = AR->getLoop(); - - // Check whether the backedge-taken count is SCEVCouldNotCompute. - // Note that this serves two purposes: It filters out loops that are - // simply not analyzable, and it covers the case where this code is - // being called from within backedge-taken count analysis, such that - // attempting to ask for the backedge-taken count would likely result - // in infinite recursion. In the later case, the analysis code will - // cope with a conservative value, and it will take care to purge - // that value once it has finished. - const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L); - - // Normally, in the cases we can prove no-overflow via a - // backedge guarding condition, we can also compute a backedge - // taken count for the loop. The exceptions are assumptions and - // guards present in the loop -- SCEV is not great at exploiting - // these to compute max backedge taken counts, but can still use - // these to prove lack of overflow. Use this fact to avoid - // doing extra work that may not pay off. - - if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards && - AC.assumptions().empty()) - return Result; - - // If the backedge is guarded by a comparison with the pre-inc value the - // addrec is safe. 
Also, if the entry is guarded by a comparison with the - // start value and the backedge is guarded by a comparison with the post-inc - // value, the addrec is safe. - ICmpInst::Predicate Pred; - const SCEV *OverflowLimit = - getSignedOverflowLimitForStep(Step, &Pred, this); - if (OverflowLimit && - (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) || - isKnownOnEveryIteration(Pred, AR, OverflowLimit))) { - Result = setFlags(Result, SCEV::FlagNSW); - } - return Result; -} -SCEV::NoWrapFlags -ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) { - SCEV::NoWrapFlags Result = AR->getNoWrapFlags(); - - if (AR->hasNoUnsignedWrap()) - return Result; - - if (!AR->isAffine()) - return Result; - - const SCEV *Step = AR->getStepRecurrence(*this); - unsigned BitWidth = getTypeSizeInBits(AR->getType()); - const Loop *L = AR->getLoop(); - - // Check whether the backedge-taken count is SCEVCouldNotCompute. - // Note that this serves two purposes: It filters out loops that are - // simply not analyzable, and it covers the case where this code is - // being called from within backedge-taken count analysis, such that - // attempting to ask for the backedge-taken count would likely result - // in infinite recursion. In the later case, the analysis code will - // cope with a conservative value, and it will take care to purge - // that value once it has finished. - const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L); - - // Normally, in the cases we can prove no-overflow via a - // backedge guarding condition, we can also compute a backedge - // taken count for the loop. The exceptions are assumptions and - // guards present in the loop -- SCEV is not great at exploiting - // these to compute max backedge taken counts, but can still use - // these to prove lack of overflow. Use this fact to avoid - // doing extra work that may not pay off. - - if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards && - AC.assumptions().empty()) - return Result; - - // If the backedge is guarded by a comparison with the pre-inc value the - // addrec is safe. Also, if the entry is guarded by a comparison with the - // start value and the backedge is guarded by a comparison with the post-inc - // value, the addrec is safe. - if (isKnownPositive(Step)) { - const SCEV *N = getConstant(APInt::getMinValue(BitWidth) - - getUnsignedRangeMax(Step)); - if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) || - isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) { - Result = setFlags(Result, SCEV::FlagNUW); - } - } - - return Result; -} - +SCEV::NoWrapFlags +ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) { + SCEV::NoWrapFlags Result = AR->getNoWrapFlags(); + + if (AR->hasNoSignedWrap()) + return Result; + + if (!AR->isAffine()) + return Result; + + const SCEV *Step = AR->getStepRecurrence(*this); + const Loop *L = AR->getLoop(); + + // Check whether the backedge-taken count is SCEVCouldNotCompute. + // Note that this serves two purposes: It filters out loops that are + // simply not analyzable, and it covers the case where this code is + // being called from within backedge-taken count analysis, such that + // attempting to ask for the backedge-taken count would likely result + // in infinite recursion. In the later case, the analysis code will + // cope with a conservative value, and it will take care to purge + // that value once it has finished. 
+ const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L); + + // Normally, in the cases we can prove no-overflow via a + // backedge guarding condition, we can also compute a backedge + // taken count for the loop. The exceptions are assumptions and + // guards present in the loop -- SCEV is not great at exploiting + // these to compute max backedge taken counts, but can still use + // these to prove lack of overflow. Use this fact to avoid + // doing extra work that may not pay off. + + if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards && + AC.assumptions().empty()) + return Result; + + // If the backedge is guarded by a comparison with the pre-inc value the + // addrec is safe. Also, if the entry is guarded by a comparison with the + // start value and the backedge is guarded by a comparison with the post-inc + // value, the addrec is safe. + ICmpInst::Predicate Pred; + const SCEV *OverflowLimit = + getSignedOverflowLimitForStep(Step, &Pred, this); + if (OverflowLimit && + (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) || + isKnownOnEveryIteration(Pred, AR, OverflowLimit))) { + Result = setFlags(Result, SCEV::FlagNSW); + } + return Result; +} +SCEV::NoWrapFlags +ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) { + SCEV::NoWrapFlags Result = AR->getNoWrapFlags(); + + if (AR->hasNoUnsignedWrap()) + return Result; + + if (!AR->isAffine()) + return Result; + + const SCEV *Step = AR->getStepRecurrence(*this); + unsigned BitWidth = getTypeSizeInBits(AR->getType()); + const Loop *L = AR->getLoop(); + + // Check whether the backedge-taken count is SCEVCouldNotCompute. + // Note that this serves two purposes: It filters out loops that are + // simply not analyzable, and it covers the case where this code is + // being called from within backedge-taken count analysis, such that + // attempting to ask for the backedge-taken count would likely result + // in infinite recursion. In the later case, the analysis code will + // cope with a conservative value, and it will take care to purge + // that value once it has finished. + const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L); + + // Normally, in the cases we can prove no-overflow via a + // backedge guarding condition, we can also compute a backedge + // taken count for the loop. The exceptions are assumptions and + // guards present in the loop -- SCEV is not great at exploiting + // these to compute max backedge taken counts, but can still use + // these to prove lack of overflow. Use this fact to avoid + // doing extra work that may not pay off. + + if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards && + AC.assumptions().empty()) + return Result; + + // If the backedge is guarded by a comparison with the pre-inc value the + // addrec is safe. Also, if the entry is guarded by a comparison with the + // start value and the backedge is guarded by a comparison with the post-inc + // value, the addrec is safe. + if (isKnownPositive(Step)) { + const SCEV *N = getConstant(APInt::getMinValue(BitWidth) - + getUnsignedRangeMax(Step)); + if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) || + isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) { + Result = setFlags(Result, SCEV::FlagNUW); + } + } + + return Result; +} + namespace { /// Represents an abstract binary operation. 
This may exist as a @@ -4542,7 +4542,7 @@ struct BinaryOp { Value *RHS; bool IsNSW = false; bool IsNUW = false; - bool IsExact = false; + bool IsExact = false; /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or /// constant expression. @@ -4555,14 +4555,14 @@ struct BinaryOp { IsNSW = OBO->hasNoSignedWrap(); IsNUW = OBO->hasNoUnsignedWrap(); } - if (auto *PEO = dyn_cast<PossiblyExactOperator>(Op)) - IsExact = PEO->isExact(); + if (auto *PEO = dyn_cast<PossiblyExactOperator>(Op)) + IsExact = PEO->isExact(); } explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false, - bool IsNUW = false, bool IsExact = false) - : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW), - IsExact(IsExact) {} + bool IsNUW = false, bool IsExact = false) + : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW), + IsExact(IsExact) {} }; } // end anonymous namespace @@ -5259,15 +5259,15 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S, bool follow(const SCEV *S) { switch (S->getSCEVType()) { - case scConstant: - case scPtrToInt: - case scTruncate: - case scZeroExtend: - case scSignExtend: - case scAddExpr: - case scMulExpr: - case scUMaxExpr: - case scSMaxExpr: + case scConstant: + case scPtrToInt: + case scTruncate: + case scZeroExtend: + case scSignExtend: + case scAddExpr: + case scMulExpr: + case scUMaxExpr: + case scSMaxExpr: case scUMinExpr: case scSMinExpr: // These expressions are available if their operand(s) is/are. @@ -5305,7 +5305,7 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S, // We do not try to smart about these at all. return setUnavailable(); } - llvm_unreachable("Unknown SCEV kind!"); + llvm_unreachable("Unknown SCEV kind!"); } bool isDone() { return TraversalDone; } @@ -5525,9 +5525,9 @@ uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) { if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) return C->getAPInt().countTrailingZeros(); - if (const SCEVPtrToIntExpr *I = dyn_cast<SCEVPtrToIntExpr>(S)) - return GetMinTrailingZeros(I->getOperand()); - + if (const SCEVPtrToIntExpr *I = dyn_cast<SCEVPtrToIntExpr>(S)) + return GetMinTrailingZeros(I->getOperand()); + if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S)) return std::min(GetMinTrailingZeros(T->getOperand()), (uint32_t)getTypeSizeInBits(T->getType())); @@ -5619,15 +5619,15 @@ static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { return None; } -void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec, - SCEV::NoWrapFlags Flags) { - if (AddRec->getNoWrapFlags(Flags) != Flags) { - AddRec->setNoWrapFlags(Flags); - UnsignedRanges.erase(AddRec); - SignedRanges.erase(AddRec); - } -} - +void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec, + SCEV::NoWrapFlags Flags) { + if (AddRec->getNoWrapFlags(Flags) != Flags) { + AddRec->setNoWrapFlags(Flags); + UnsignedRanges.erase(AddRec); + SignedRanges.erase(AddRec); + } +} + /// Determine the range for a particular SCEV. If SignHint is /// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges /// with a "cleaner" unsigned (resp. signed) representation. 
@@ -5742,11 +5742,11 @@ ScalarEvolution::getRangeRef(const SCEV *S, RangeType)); } - if (const SCEVPtrToIntExpr *PtrToInt = dyn_cast<SCEVPtrToIntExpr>(S)) { - ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint); - return setRange(PtrToInt, SignHint, X); - } - + if (const SCEVPtrToIntExpr *PtrToInt = dyn_cast<SCEVPtrToIntExpr>(S)) { + ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint); + return setRange(PtrToInt, SignHint, X); + } + if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) { ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint); return setRange(Trunc, SignHint, @@ -5799,28 +5799,28 @@ ScalarEvolution::getRangeRef(const SCEV *S, auto RangeFromAffine = getRangeForAffineAR( AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, BitWidth); - ConservativeResult = - ConservativeResult.intersectWith(RangeFromAffine, RangeType); + ConservativeResult = + ConservativeResult.intersectWith(RangeFromAffine, RangeType); auto RangeFromFactoring = getRangeViaFactoring( AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, BitWidth); - ConservativeResult = - ConservativeResult.intersectWith(RangeFromFactoring, RangeType); - } - - // Now try symbolic BE count and more powerful methods. - if (UseExpensiveRangeSharpening) { - const SCEV *SymbolicMaxBECount = - getSymbolicMaxBackedgeTakenCount(AddRec->getLoop()); - if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) && - getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && - AddRec->hasNoSelfWrap()) { - auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR( - AddRec, SymbolicMaxBECount, BitWidth, SignHint); + ConservativeResult = + ConservativeResult.intersectWith(RangeFromFactoring, RangeType); + } + + // Now try symbolic BE count and more powerful methods. + if (UseExpensiveRangeSharpening) { + const SCEV *SymbolicMaxBECount = + getSymbolicMaxBackedgeTakenCount(AddRec->getLoop()); + if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) && + getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && + AddRec->hasNoSelfWrap()) { + auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR( + AddRec, SymbolicMaxBECount, BitWidth, SignHint); ConservativeResult = - ConservativeResult.intersectWith(RangeFromAffineNew, RangeType); - } + ConservativeResult.intersectWith(RangeFromAffineNew, RangeType); + } } } @@ -5991,74 +5991,74 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, return SR.intersectWith(UR, ConstantRange::Smallest); } -ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( - const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth, - ScalarEvolution::RangeSignHint SignHint) { - assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n"); - assert(AddRec->hasNoSelfWrap() && - "This only works for non-self-wrapping AddRecs!"); - const bool IsSigned = SignHint == HINT_RANGE_SIGNED; - const SCEV *Step = AddRec->getStepRecurrence(*this); - // Only deal with constant step to save compile time. - if (!isa<SCEVConstant>(Step)) - return ConstantRange::getFull(BitWidth); - // Let's make sure that we can prove that we do not self-wrap during - // MaxBECount iterations. We need this because MaxBECount is a maximum - // iteration count estimate, and we might infer nw from some exit for which we - // do not know max exit count (or any other side reasoning). - // TODO: Turn into assert at some point. 
- if (getTypeSizeInBits(MaxBECount->getType()) > - getTypeSizeInBits(AddRec->getType())) - return ConstantRange::getFull(BitWidth); - MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType()); - const SCEV *RangeWidth = getMinusOne(AddRec->getType()); - const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step)); - const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs); - if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount, - MaxItersWithoutWrap)) - return ConstantRange::getFull(BitWidth); - - ICmpInst::Predicate LEPred = - IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; - ICmpInst::Predicate GEPred = - IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; - const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this); - - // We know that there is no self-wrap. Let's take Start and End values and - // look at all intermediate values V1, V2, ..., Vn that IndVar takes during - // the iteration. They either lie inside the range [Min(Start, End), - // Max(Start, End)] or outside it: - // - // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax; - // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax; - // - // No self wrap flag guarantees that the intermediate values cannot be BOTH - // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that - // knowledge, let's try to prove that we are dealing with Case 1. It is so if - // Start <= End and step is positive, or Start >= End and step is negative. - const SCEV *Start = AddRec->getStart(); - ConstantRange StartRange = getRangeRef(Start, SignHint); - ConstantRange EndRange = getRangeRef(End, SignHint); - ConstantRange RangeBetween = StartRange.unionWith(EndRange); - // If they already cover full iteration space, we will know nothing useful - // even if we prove what we want to prove. - if (RangeBetween.isFullSet()) - return RangeBetween; - // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax). - bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet() - : RangeBetween.isWrappedSet(); - if (IsWrappedSet) - return ConstantRange::getFull(BitWidth); - - if (isKnownPositive(Step) && - isKnownPredicateViaConstantRanges(LEPred, Start, End)) - return RangeBetween; - else if (isKnownNegative(Step) && - isKnownPredicateViaConstantRanges(GEPred, Start, End)) - return RangeBetween; - return ConstantRange::getFull(BitWidth); -} - +ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR( + const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth, + ScalarEvolution::RangeSignHint SignHint) { + assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n"); + assert(AddRec->hasNoSelfWrap() && + "This only works for non-self-wrapping AddRecs!"); + const bool IsSigned = SignHint == HINT_RANGE_SIGNED; + const SCEV *Step = AddRec->getStepRecurrence(*this); + // Only deal with constant step to save compile time. + if (!isa<SCEVConstant>(Step)) + return ConstantRange::getFull(BitWidth); + // Let's make sure that we can prove that we do not self-wrap during + // MaxBECount iterations. We need this because MaxBECount is a maximum + // iteration count estimate, and we might infer nw from some exit for which we + // do not know max exit count (or any other side reasoning). + // TODO: Turn into assert at some point. 
+ if (getTypeSizeInBits(MaxBECount->getType()) > + getTypeSizeInBits(AddRec->getType())) + return ConstantRange::getFull(BitWidth); + MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType()); + const SCEV *RangeWidth = getMinusOne(AddRec->getType()); + const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step)); + const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs); + if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount, + MaxItersWithoutWrap)) + return ConstantRange::getFull(BitWidth); + + ICmpInst::Predicate LEPred = + IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; + ICmpInst::Predicate GEPred = + IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; + const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this); + + // We know that there is no self-wrap. Let's take Start and End values and + // look at all intermediate values V1, V2, ..., Vn that IndVar takes during + // the iteration. They either lie inside the range [Min(Start, End), + // Max(Start, End)] or outside it: + // + // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax; + // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax; + // + // No self wrap flag guarantees that the intermediate values cannot be BOTH + // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that + // knowledge, let's try to prove that we are dealing with Case 1. It is so if + // Start <= End and step is positive, or Start >= End and step is negative. + const SCEV *Start = AddRec->getStart(); + ConstantRange StartRange = getRangeRef(Start, SignHint); + ConstantRange EndRange = getRangeRef(End, SignHint); + ConstantRange RangeBetween = StartRange.unionWith(EndRange); + // If they already cover full iteration space, we will know nothing useful + // even if we prove what we want to prove. + if (RangeBetween.isFullSet()) + return RangeBetween; + // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax). + bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet() + : RangeBetween.isWrappedSet(); + if (IsWrappedSet) + return ConstantRange::getFull(BitWidth); + + if (isKnownPositive(Step) && + isKnownPredicateViaConstantRanges(LEPred, Start, End)) + return RangeBetween; + else if (isKnownNegative(Step) && + isKnownPredicateViaConstantRanges(GEPred, Start, End)) + return RangeBetween; + return ConstantRange::getFull(BitWidth); +} + ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, const SCEV *Step, const SCEV *MaxBECount, @@ -6091,7 +6091,7 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, } // Peel off a cast operation - if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) { + if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) { CastOp = SCast->getSCEVType(); S = SCast->getOperand(); } @@ -6292,7 +6292,7 @@ bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) { const Instruction *Poison = PoisonStack.pop_back_val(); for (auto *PoisonUser : Poison->users()) { - if (propagatesPoison(cast<Operator>(PoisonUser))) { + if (propagatesPoison(cast<Operator>(PoisonUser))) { if (Pushed.insert(cast<Instruction>(PoisonUser)).second) PoisonStack.push_back(cast<Instruction>(PoisonUser)); } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) { @@ -6356,7 +6356,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) return getConstant(CI); else if (isa<ConstantPointerNull>(V)) - // FIXME: we shouldn't special-case null pointer constant. 
+ // FIXME: we shouldn't special-case null pointer constant. return getZero(V->getType()); else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee()); @@ -6647,15 +6647,15 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { } } } - if (BO->IsExact) { - // Given exact arithmetic in-bounds right-shift by a constant, - // we can lower it into: (abs(x) EXACT/u (1<<C)) * signum(x) - const SCEV *X = getSCEV(BO->LHS); - const SCEV *AbsX = getAbsExpr(X, /*IsNSW=*/false); - APInt Mult = APInt::getOneBitSet(BitWidth, AShrAmt); - const SCEV *Div = getUDivExactExpr(AbsX, getConstant(Mult)); - return getMulExpr(Div, getSignumExpr(X), SCEV::FlagNSW); - } + if (BO->IsExact) { + // Given exact arithmetic in-bounds right-shift by a constant, + // we can lower it into: (abs(x) EXACT/u (1<<C)) * signum(x) + const SCEV *X = getSCEV(BO->LHS); + const SCEV *AbsX = getAbsExpr(X, /*IsNSW=*/false); + APInt Mult = APInt::getOneBitSet(BitWidth, AShrAmt); + const SCEV *Div = getUDivExactExpr(AbsX, getConstant(Mult)); + return getMulExpr(Div, getSignumExpr(X), SCEV::FlagNSW); + } break; } } @@ -6692,29 +6692,29 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { return getSCEV(U->getOperand(0)); break; - case Instruction::PtrToInt: { - // Pointer to integer cast is straight-forward, so do model it. - Value *Ptr = U->getOperand(0); - const SCEV *Op = getSCEV(Ptr); - Type *DstIntTy = U->getType(); - // SCEV doesn't have constant pointer expression type, but it supports - // nullptr constant (and only that one), which is modelled in SCEV as a - // zero integer constant. So just skip the ptrtoint cast for constants. - if (isa<SCEVConstant>(Op)) - return getTruncateOrZeroExtend(Op, DstIntTy); - Type *PtrTy = Ptr->getType(); - Type *IntPtrTy = getDataLayout().getIntPtrType(PtrTy); - // But only if effective SCEV (integer) type is wide enough to represent - // all possible pointer values. - if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(PtrTy)) != - getDataLayout().getTypeSizeInBits(IntPtrTy)) - return getUnknown(V); - return getPtrToIntExpr(Op, DstIntTy); - } - case Instruction::IntToPtr: - // Just don't deal with inttoptr casts. - return getUnknown(V); - + case Instruction::PtrToInt: { + // Pointer to integer cast is straight-forward, so do model it. + Value *Ptr = U->getOperand(0); + const SCEV *Op = getSCEV(Ptr); + Type *DstIntTy = U->getType(); + // SCEV doesn't have constant pointer expression type, but it supports + // nullptr constant (and only that one), which is modelled in SCEV as a + // zero integer constant. So just skip the ptrtoint cast for constants. + if (isa<SCEVConstant>(Op)) + return getTruncateOrZeroExtend(Op, DstIntTy); + Type *PtrTy = Ptr->getType(); + Type *IntPtrTy = getDataLayout().getIntPtrType(PtrTy); + // But only if effective SCEV (integer) type is wide enough to represent + // all possible pointer values. + if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(PtrTy)) != + getDataLayout().getTypeSizeInBits(IntPtrTy)) + return getUnknown(V); + return getPtrToIntExpr(Op, DstIntTy); + } + case Instruction::IntToPtr: + // Just don't deal with inttoptr casts. + return getUnknown(V); + case Instruction::SDiv: // If both operands are non-negative, this is just an udiv. 
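// Editor's note: standalone illustration (plain C++, not upstream code). The
// AShr hunk above lowers an *exact* arithmetic right shift by a constant C --
// one where no set bits are discarded, so x is a multiple of 2^C -- into
// (abs(x) divided exactly by 2^C) * signum(x). Since the shift is exact,
// truncating division by 2^C stands in for the ashr result in this sketch.
#include <cassert>
#include <cstdint>

int main() {
  const unsigned C = 4;
  const int64_t Pow2C = int64_t(1) << C;
  for (int64_t X = -32768; X <= 32767; ++X) {
    if (X % Pow2C != 0)
      continue;                            // only the exact case is covered
    int64_t Shifted = X / Pow2C;           // exact ashr: no remainder is lost
    int64_t Abs = X < 0 ? -X : X;
    int64_t Signum = (X > 0) - (X < 0);
    assert(Shifted == (Abs / Pow2C) * Signum); // (abs(x) /u 2^C) * signum(x)
  }
  return 0;
}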
if (isKnownNonNegative(getSCEV(U->getOperand(0))) && @@ -6749,45 +6749,45 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { case Instruction::Invoke: if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand()) return getSCEV(RV); - - if (auto *II = dyn_cast<IntrinsicInst>(U)) { - switch (II->getIntrinsicID()) { - case Intrinsic::abs: - return getAbsExpr( - getSCEV(II->getArgOperand(0)), - /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne()); - case Intrinsic::umax: - return getUMaxExpr(getSCEV(II->getArgOperand(0)), - getSCEV(II->getArgOperand(1))); - case Intrinsic::umin: - return getUMinExpr(getSCEV(II->getArgOperand(0)), - getSCEV(II->getArgOperand(1))); - case Intrinsic::smax: - return getSMaxExpr(getSCEV(II->getArgOperand(0)), - getSCEV(II->getArgOperand(1))); - case Intrinsic::smin: - return getSMinExpr(getSCEV(II->getArgOperand(0)), - getSCEV(II->getArgOperand(1))); - case Intrinsic::usub_sat: { - const SCEV *X = getSCEV(II->getArgOperand(0)); - const SCEV *Y = getSCEV(II->getArgOperand(1)); - const SCEV *ClampedY = getUMinExpr(X, Y); - return getMinusSCEV(X, ClampedY, SCEV::FlagNUW); - } - case Intrinsic::uadd_sat: { - const SCEV *X = getSCEV(II->getArgOperand(0)); - const SCEV *Y = getSCEV(II->getArgOperand(1)); - const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y)); - return getAddExpr(ClampedX, Y, SCEV::FlagNUW); - } - case Intrinsic::start_loop_iterations: - // A start_loop_iterations is just equivalent to the first operand for - // SCEV purposes. - return getSCEV(II->getArgOperand(0)); - default: - break; - } - } + + if (auto *II = dyn_cast<IntrinsicInst>(U)) { + switch (II->getIntrinsicID()) { + case Intrinsic::abs: + return getAbsExpr( + getSCEV(II->getArgOperand(0)), + /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne()); + case Intrinsic::umax: + return getUMaxExpr(getSCEV(II->getArgOperand(0)), + getSCEV(II->getArgOperand(1))); + case Intrinsic::umin: + return getUMinExpr(getSCEV(II->getArgOperand(0)), + getSCEV(II->getArgOperand(1))); + case Intrinsic::smax: + return getSMaxExpr(getSCEV(II->getArgOperand(0)), + getSCEV(II->getArgOperand(1))); + case Intrinsic::smin: + return getSMinExpr(getSCEV(II->getArgOperand(0)), + getSCEV(II->getArgOperand(1))); + case Intrinsic::usub_sat: { + const SCEV *X = getSCEV(II->getArgOperand(0)); + const SCEV *Y = getSCEV(II->getArgOperand(1)); + const SCEV *ClampedY = getUMinExpr(X, Y); + return getMinusSCEV(X, ClampedY, SCEV::FlagNUW); + } + case Intrinsic::uadd_sat: { + const SCEV *X = getSCEV(II->getArgOperand(0)); + const SCEV *Y = getSCEV(II->getArgOperand(1)); + const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y)); + return getAddExpr(ClampedX, Y, SCEV::FlagNUW); + } + case Intrinsic::start_loop_iterations: + // A start_loop_iterations is just equivalent to the first operand for + // SCEV purposes. + return getSCEV(II->getArgOperand(0)); + default: + break; + } + } break; } @@ -6820,9 +6820,9 @@ unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) { return 0; } -unsigned -ScalarEvolution::getSmallConstantTripCount(const Loop *L, - const BasicBlock *ExitingBlock) { +unsigned +ScalarEvolution::getSmallConstantTripCount(const Loop *L, + const BasicBlock *ExitingBlock) { assert(ExitingBlock && "Must pass a non-null exiting block!"); assert(L->isLoopExiting(ExitingBlock) && "Exiting block must actually branch out of the loop!"); @@ -6859,7 +6859,7 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) { /// that control exits the loop via ExitingBlock. 
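// Editor's note: standalone illustration (plain C++, not upstream code). The
// intrinsic hunk above models the unsigned saturating intrinsics with plain
// wrapping arithmetic: usub_sat(x, y) == x - umin(x, y), and
// uadd_sat(x, y) == umin(x, ~y) + y. The reference semantics below are my own
// re-statement; the check is exhaustive over 8-bit operands.
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned XI = 0; XI < 256; ++XI) {
    for (unsigned YI = 0; YI < 256; ++YI) {
      uint8_t X = uint8_t(XI), Y = uint8_t(YI);
      // Reference saturating results.
      uint8_t USubSat = X > Y ? uint8_t(X - Y) : uint8_t(0);
      uint8_t UAddSat = (XI + YI > 255) ? uint8_t(255) : uint8_t(XI + YI);
      // The SCEV-style rewrites: clamp one operand, then use ordinary ops.
      uint8_t ClampedY = X < Y ? X : Y;         // umin(X, Y)
      uint8_t NotY = uint8_t(~Y);
      uint8_t ClampedX = X < NotY ? X : NotY;   // umin(X, ~Y)
      assert(uint8_t(X - ClampedY) == USubSat);
      assert(uint8_t(ClampedX + Y) == UAddSat);
    }
  }
  return 0;
}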
unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, - const BasicBlock *ExitingBlock) { + const BasicBlock *ExitingBlock) { assert(ExitingBlock && "Must pass a non-null exiting block!"); assert(L->isLoopExiting(ExitingBlock) && "Exiting block must actually branch out of the loop!"); @@ -6890,14 +6890,14 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L, } const SCEV *ScalarEvolution::getExitCount(const Loop *L, - const BasicBlock *ExitingBlock, + const BasicBlock *ExitingBlock, ExitCountKind Kind) { switch (Kind) { case Exact: - case SymbolicMaximum: + case SymbolicMaximum: return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); case ConstantMaximum: - return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this); + return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this); }; llvm_unreachable("Invalid ExitCountKind!"); } @@ -6914,15 +6914,15 @@ const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L, case Exact: return getBackedgeTakenInfo(L).getExact(L, this); case ConstantMaximum: - return getBackedgeTakenInfo(L).getConstantMax(this); - case SymbolicMaximum: - return getBackedgeTakenInfo(L).getSymbolicMax(L, this); + return getBackedgeTakenInfo(L).getConstantMax(this); + case SymbolicMaximum: + return getBackedgeTakenInfo(L).getSymbolicMax(L, this); }; llvm_unreachable("Invalid ExitCountKind!"); } bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) { - return getBackedgeTakenInfo(L).isConstantMaxOrZero(this); + return getBackedgeTakenInfo(L).isConstantMaxOrZero(this); } /// Push PHI nodes in the header of the given loop onto the given Worklist. @@ -6952,7 +6952,7 @@ ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) { return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result); } -ScalarEvolution::BackedgeTakenInfo & +ScalarEvolution::BackedgeTakenInfo & ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // Initially insert an invalid entry for this loop. If the insertion // succeeds, proceed to actually compute a backedge-taken count and @@ -6976,11 +6976,11 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { const SCEV *BEExact = Result.getExact(L, this); if (BEExact != getCouldNotCompute()) { assert(isLoopInvariant(BEExact, L) && - isLoopInvariant(Result.getConstantMax(this), L) && + isLoopInvariant(Result.getConstantMax(this), L) && "Computed backedge-taken count isn't loop invariant for loop!"); ++NumTripCountsComputed; - } else if (Result.getConstantMax(this) == getCouldNotCompute() && - isa<PHINode>(L->getHeader()->begin())) { + } else if (Result.getConstantMax(this) == getCouldNotCompute() && + isa<PHINode>(L->getHeader()->begin())) { // Only count loops that have phi nodes as not being computable. ++NumTripCountsNotComputed; } @@ -7221,7 +7221,7 @@ ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE, /// Get the exact not taken count for this loop exit. 
const SCEV * -ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock, +ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock, ScalarEvolution *SE) const { for (auto &ENT : ExitNotTaken) if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate()) @@ -7230,8 +7230,8 @@ ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock, return SE->getCouldNotCompute(); } -const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax( - const BasicBlock *ExitingBlock, ScalarEvolution *SE) const { +const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax( + const BasicBlock *ExitingBlock, ScalarEvolution *SE) const { for (auto &ENT : ExitNotTaken) if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate()) return ENT.MaxNotTaken; @@ -7239,32 +7239,32 @@ const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax( return SE->getCouldNotCompute(); } -/// getConstantMax - Get the constant max backedge taken count for the loop. +/// getConstantMax - Get the constant max backedge taken count for the loop. const SCEV * -ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const { +ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const { auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) { return !ENT.hasAlwaysTruePredicate(); }; - if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getConstantMax()) + if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getConstantMax()) return SE->getCouldNotCompute(); - assert((isa<SCEVCouldNotCompute>(getConstantMax()) || - isa<SCEVConstant>(getConstantMax())) && + assert((isa<SCEVCouldNotCompute>(getConstantMax()) || + isa<SCEVConstant>(getConstantMax())) && "No point in having a non-constant max backedge taken count!"); - return getConstantMax(); -} - -const SCEV * -ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L, - ScalarEvolution *SE) { - if (!SymbolicMax) - SymbolicMax = SE->computeSymbolicMaxBackedgeTakenCount(L); - return SymbolicMax; -} - -bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero( - ScalarEvolution *SE) const { + return getConstantMax(); +} + +const SCEV * +ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L, + ScalarEvolution *SE) { + if (!SymbolicMax) + SymbolicMax = SE->computeSymbolicMaxBackedgeTakenCount(L); + return SymbolicMax; +} + +bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero( + ScalarEvolution *SE) const { auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) { return !ENT.hasAlwaysTruePredicate(); }; @@ -7273,8 +7273,8 @@ bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero( bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, ScalarEvolution *SE) const { - if (getConstantMax() && getConstantMax() != SE->getCouldNotCompute() && - SE->hasOperand(getConstantMax(), S)) + if (getConstantMax() && getConstantMax() != SE->getCouldNotCompute() && + SE->hasOperand(getConstantMax(), S)) return true; for (auto &ENT : ExitNotTaken) @@ -7327,9 +7327,9 @@ ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M, /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each /// computable exit into a persistent ExitNotTakenInfo array. 
ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( - ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts, - bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero) - : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) { + ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts, + bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero) + : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) { using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo; ExitNotTaken.reserve(ExitCounts.size()); @@ -7349,8 +7349,8 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, EL.MaxNotTaken, std::move(Predicate)); }); - assert((isa<SCEVCouldNotCompute>(ConstantMax) || - isa<SCEVConstant>(ConstantMax)) && + assert((isa<SCEVCouldNotCompute>(ConstantMax) || + isa<SCEVConstant>(ConstantMax)) && "No point in having a non-constant max backedge taken count!"); } @@ -7539,10 +7539,10 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached( ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, bool ControlsExit, bool AllowPredicates) { - // Handle BinOp conditions (And, Or). - if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp( - Cache, L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates)) - return *LimitFromBinOp; + // Handle BinOp conditions (And, Or). + if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp( + Cache, L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates)) + return *LimitFromBinOp; // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. @@ -7574,95 +7574,95 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl( return computeExitCountExhaustively(L, ExitCond, ExitIfTrue); } -Optional<ScalarEvolution::ExitLimit> -ScalarEvolution::computeExitLimitFromCondFromBinOp( - ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, - bool ControlsExit, bool AllowPredicates) { - // Check if the controlling expression for this loop is an And or Or. - Value *Op0, *Op1; - bool IsAnd = false; - if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) - IsAnd = true; - else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) - IsAnd = false; - else - return None; - - // EitherMayExit is true in these two cases: - // br (and Op0 Op1), loop, exit - // br (or Op0 Op1), exit, loop - bool EitherMayExit = IsAnd ^ ExitIfTrue; - ExitLimit EL0 = computeExitLimitFromCondCached(Cache, L, Op0, ExitIfTrue, - ControlsExit && !EitherMayExit, - AllowPredicates); - ExitLimit EL1 = computeExitLimitFromCondCached(Cache, L, Op1, ExitIfTrue, - ControlsExit && !EitherMayExit, - AllowPredicates); - - // Be robust against unsimplified IR for the form "op i1 X, NeutralElement" - const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd); - if (isa<ConstantInt>(Op1)) - return Op1 == NeutralElement ? EL0 : EL1; - if (isa<ConstantInt>(Op0)) - return Op0 == NeutralElement ? EL1 : EL0; - - const SCEV *BECount = getCouldNotCompute(); - const SCEV *MaxBECount = getCouldNotCompute(); - if (EitherMayExit) { - // Both conditions must be same for the loop to continue executing. - // Choose the less conservative count. 
- // If ExitCond is a short-circuit form (select), using - // umin(EL0.ExactNotTaken, EL1.ExactNotTaken) is unsafe in general. - // To see the detailed examples, please see - // test/Analysis/ScalarEvolution/exit-count-select.ll - bool PoisonSafe = isa<BinaryOperator>(ExitCond); - if (!PoisonSafe) - // Even if ExitCond is select, we can safely derive BECount using both - // EL0 and EL1 in these cases: - // (1) EL0.ExactNotTaken is non-zero - // (2) EL1.ExactNotTaken is non-poison - // (3) EL0.ExactNotTaken is zero (BECount should be simply zero and - // it cannot be umin(0, ..)) - // The PoisonSafe assignment below is simplified and the assertion after - // BECount calculation fully guarantees the condition (3). - PoisonSafe = isa<SCEVConstant>(EL0.ExactNotTaken) || - isa<SCEVConstant>(EL1.ExactNotTaken); - if (EL0.ExactNotTaken != getCouldNotCompute() && - EL1.ExactNotTaken != getCouldNotCompute() && PoisonSafe) { - BECount = - getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); - - // If EL0.ExactNotTaken was zero and ExitCond was a short-circuit form, - // it should have been simplified to zero (see the condition (3) above) - assert(!isa<BinaryOperator>(ExitCond) || !EL0.ExactNotTaken->isZero() || - BECount->isZero()); - } - if (EL0.MaxNotTaken == getCouldNotCompute()) - MaxBECount = EL1.MaxNotTaken; - else if (EL1.MaxNotTaken == getCouldNotCompute()) - MaxBECount = EL0.MaxNotTaken; - else - MaxBECount = getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken); - } else { - // Both conditions must be same at the same time for the loop to exit. - // For now, be conservative. - if (EL0.ExactNotTaken == EL1.ExactNotTaken) - BECount = EL0.ExactNotTaken; - } - - // There are cases (e.g. PR26207) where computeExitLimitFromCond is able - // to be more aggressive when computing BECount than when computing - // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and - // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken - // to not. - if (isa<SCEVCouldNotCompute>(MaxBECount) && - !isa<SCEVCouldNotCompute>(BECount)) - MaxBECount = getConstant(getUnsignedRangeMax(BECount)); - - return ExitLimit(BECount, MaxBECount, false, - { &EL0.Predicates, &EL1.Predicates }); -} - +Optional<ScalarEvolution::ExitLimit> +ScalarEvolution::computeExitLimitFromCondFromBinOp( + ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue, + bool ControlsExit, bool AllowPredicates) { + // Check if the controlling expression for this loop is an And or Or. + Value *Op0, *Op1; + bool IsAnd = false; + if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1)))) + IsAnd = true; + else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1)))) + IsAnd = false; + else + return None; + + // EitherMayExit is true in these two cases: + // br (and Op0 Op1), loop, exit + // br (or Op0 Op1), exit, loop + bool EitherMayExit = IsAnd ^ ExitIfTrue; + ExitLimit EL0 = computeExitLimitFromCondCached(Cache, L, Op0, ExitIfTrue, + ControlsExit && !EitherMayExit, + AllowPredicates); + ExitLimit EL1 = computeExitLimitFromCondCached(Cache, L, Op1, ExitIfTrue, + ControlsExit && !EitherMayExit, + AllowPredicates); + + // Be robust against unsimplified IR for the form "op i1 X, NeutralElement" + const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd); + if (isa<ConstantInt>(Op1)) + return Op1 == NeutralElement ? EL0 : EL1; + if (isa<ConstantInt>(Op0)) + return Op0 == NeutralElement ? 
EL1 : EL0; + + const SCEV *BECount = getCouldNotCompute(); + const SCEV *MaxBECount = getCouldNotCompute(); + if (EitherMayExit) { + // Both conditions must be same for the loop to continue executing. + // Choose the less conservative count. + // If ExitCond is a short-circuit form (select), using + // umin(EL0.ExactNotTaken, EL1.ExactNotTaken) is unsafe in general. + // To see the detailed examples, please see + // test/Analysis/ScalarEvolution/exit-count-select.ll + bool PoisonSafe = isa<BinaryOperator>(ExitCond); + if (!PoisonSafe) + // Even if ExitCond is select, we can safely derive BECount using both + // EL0 and EL1 in these cases: + // (1) EL0.ExactNotTaken is non-zero + // (2) EL1.ExactNotTaken is non-poison + // (3) EL0.ExactNotTaken is zero (BECount should be simply zero and + // it cannot be umin(0, ..)) + // The PoisonSafe assignment below is simplified and the assertion after + // BECount calculation fully guarantees the condition (3). + PoisonSafe = isa<SCEVConstant>(EL0.ExactNotTaken) || + isa<SCEVConstant>(EL1.ExactNotTaken); + if (EL0.ExactNotTaken != getCouldNotCompute() && + EL1.ExactNotTaken != getCouldNotCompute() && PoisonSafe) { + BECount = + getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); + + // If EL0.ExactNotTaken was zero and ExitCond was a short-circuit form, + // it should have been simplified to zero (see the condition (3) above) + assert(!isa<BinaryOperator>(ExitCond) || !EL0.ExactNotTaken->isZero() || + BECount->isZero()); + } + if (EL0.MaxNotTaken == getCouldNotCompute()) + MaxBECount = EL1.MaxNotTaken; + else if (EL1.MaxNotTaken == getCouldNotCompute()) + MaxBECount = EL0.MaxNotTaken; + else + MaxBECount = getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken); + } else { + // Both conditions must be same at the same time for the loop to exit. + // For now, be conservative. + if (EL0.ExactNotTaken == EL1.ExactNotTaken) + BECount = EL0.ExactNotTaken; + } + + // There are cases (e.g. PR26207) where computeExitLimitFromCond is able + // to be more aggressive when computing BECount than when computing + // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and + // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken + // to not. + if (isa<SCEVCouldNotCompute>(MaxBECount) && + !isa<SCEVCouldNotCompute>(BECount)) + MaxBECount = getConstant(getUnsignedRangeMax(BECount)); + + return ExitLimit(BECount, MaxBECount, false, + { &EL0.Predicates, &EL1.Predicates }); +} + ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, @@ -8357,110 +8357,110 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { /// SCEVConstant, because SCEVConstant is restricted to ConstantInt. /// Returns NULL if the SCEV isn't representable as a Constant. 
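// Editor's note: standalone illustration (plain C++, not upstream code). The
// computeExitLimitFromCondFromBinOp hunk above observes that for
//   br (and Op0 Op1), loop, exit
// the loop keeps running only while both conditions hold, so (in the
// poison-safe cases) the exact count is the unsigned minimum of the two
// per-condition counts. A tiny simulation with conditions "i < N0", "i < N1":
#include <algorithm>
#include <cassert>

static unsigned backedgesTaken(unsigned N0, unsigned N1) {
  unsigned I = 0, Backedges = 0;
  while (I < N0 && I < N1) {   // br (and Op0 Op1), loop, exit
    ++I;
    ++Backedges;               // one backedge per completed iteration
  }
  return Backedges;
}

int main() {
  for (unsigned N0 = 0; N0 < 50; ++N0)
    for (unsigned N1 = 0; N1 < 50; ++N1)
      assert(backedgesTaken(N0, N1) == std::min(N0, N1));
  return 0;
}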
static Constant *BuildConstantFromSCEV(const SCEV *V) { - switch (V->getSCEVType()) { - case scCouldNotCompute: - case scAddRecExpr: - return nullptr; - case scConstant: - return cast<SCEVConstant>(V)->getValue(); - case scUnknown: - return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue()); - case scSignExtend: { - const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V); - if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand())) - return ConstantExpr::getSExt(CastOp, SS->getType()); - return nullptr; - } - case scZeroExtend: { - const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V); - if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand())) - return ConstantExpr::getZExt(CastOp, SZ->getType()); - return nullptr; - } - case scPtrToInt: { - const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V); - if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand())) - return ConstantExpr::getPtrToInt(CastOp, P2I->getType()); - - return nullptr; - } - case scTruncate: { - const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V); - if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand())) - return ConstantExpr::getTrunc(CastOp, ST->getType()); - return nullptr; - } - case scAddExpr: { - const SCEVAddExpr *SA = cast<SCEVAddExpr>(V); - if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) { - if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) { - unsigned AS = PTy->getAddressSpace(); - Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); - C = ConstantExpr::getBitCast(C, DestPtrTy); - } - for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) { - Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i)); - if (!C2) - return nullptr; - - // First pointer! - if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) { - unsigned AS = C2->getType()->getPointerAddressSpace(); - std::swap(C, C2); + switch (V->getSCEVType()) { + case scCouldNotCompute: + case scAddRecExpr: + return nullptr; + case scConstant: + return cast<SCEVConstant>(V)->getValue(); + case scUnknown: + return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue()); + case scSignExtend: { + const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V); + if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand())) + return ConstantExpr::getSExt(CastOp, SS->getType()); + return nullptr; + } + case scZeroExtend: { + const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V); + if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand())) + return ConstantExpr::getZExt(CastOp, SZ->getType()); + return nullptr; + } + case scPtrToInt: { + const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V); + if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand())) + return ConstantExpr::getPtrToInt(CastOp, P2I->getType()); + + return nullptr; + } + case scTruncate: { + const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V); + if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand())) + return ConstantExpr::getTrunc(CastOp, ST->getType()); + return nullptr; + } + case scAddExpr: { + const SCEVAddExpr *SA = cast<SCEVAddExpr>(V); + if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) { + if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) { + unsigned AS = PTy->getAddressSpace(); + Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); + C = ConstantExpr::getBitCast(C, DestPtrTy); + } + for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) { + Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i)); + if (!C2) + return nullptr; + + // First pointer! 
+ if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) { + unsigned AS = C2->getType()->getPointerAddressSpace(); + std::swap(C, C2); Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS); - // The offsets have been converted to bytes. We can add bytes to an - // i8* by GEP with the byte count in the first index. + // The offsets have been converted to bytes. We can add bytes to an + // i8* by GEP with the byte count in the first index. C = ConstantExpr::getBitCast(C, DestPtrTy); } - // Don't bother trying to sum two pointers. We probably can't - // statically compute a load that results from it anyway. - if (C2->getType()->isPointerTy()) - return nullptr; - - if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) { - if (PTy->getElementType()->isStructTy()) - C2 = ConstantExpr::getIntegerCast( - C2, Type::getInt32Ty(C->getContext()), true); - C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2); - } else - C = ConstantExpr::getAdd(C, C2); + // Don't bother trying to sum two pointers. We probably can't + // statically compute a load that results from it anyway. + if (C2->getType()->isPointerTy()) + return nullptr; + + if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) { + if (PTy->getElementType()->isStructTy()) + C2 = ConstantExpr::getIntegerCast( + C2, Type::getInt32Ty(C->getContext()), true); + C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2); + } else + C = ConstantExpr::getAdd(C, C2); } - return C; - } - return nullptr; - } - case scMulExpr: { - const SCEVMulExpr *SM = cast<SCEVMulExpr>(V); - if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) { - // Don't bother with pointers at all. - if (C->getType()->isPointerTy()) - return nullptr; - for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) { - Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i)); - if (!C2 || C2->getType()->isPointerTy()) - return nullptr; - C = ConstantExpr::getMul(C, C2); + return C; + } + return nullptr; + } + case scMulExpr: { + const SCEVMulExpr *SM = cast<SCEVMulExpr>(V); + if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) { + // Don't bother with pointers at all. + if (C->getType()->isPointerTy()) + return nullptr; + for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) { + Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i)); + if (!C2 || C2->getType()->isPointerTy()) + return nullptr; + C = ConstantExpr::getMul(C, C2); } - return C; - } - return nullptr; - } - case scUDivExpr: { - const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V); - if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS())) - if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS())) - if (LHS->getType() == RHS->getType()) - return ConstantExpr::getUDiv(LHS, RHS); - return nullptr; - } - case scSMaxExpr: - case scUMaxExpr: - case scSMinExpr: - case scUMinExpr: - return nullptr; // TODO: smax, umax, smin, umax. - } - llvm_unreachable("Unknown SCEV kind!"); + return C; + } + return nullptr; + } + case scUDivExpr: { + const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V); + if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS())) + if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS())) + if (LHS->getType() == RHS->getType()) + return ConstantExpr::getUDiv(LHS, RHS); + return nullptr; + } + case scSMaxExpr: + case scUMaxExpr: + case scSMinExpr: + case scUMinExpr: + return nullptr; // TODO: smax, umax, smin, umax. 
+ } + llvm_unreachable("Unknown SCEV kind!"); } const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { @@ -8471,22 +8471,22 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) { if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) { if (PHINode *PN = dyn_cast<PHINode>(I)) { - const Loop *CurrLoop = this->LI[I->getParent()]; + const Loop *CurrLoop = this->LI[I->getParent()]; // Looking for loop exit value. - if (CurrLoop && CurrLoop->getParentLoop() == L && - PN->getParent() == CurrLoop->getHeader()) { + if (CurrLoop && CurrLoop->getParentLoop() == L && + PN->getParent() == CurrLoop->getHeader()) { // Okay, there is no closed form solution for the PHI node. Check // to see if the loop that contains it has a known backedge-taken // count. If so, we may be able to force computation of the exit // value. - const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop); + const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop); // This trivial case can show up in some degenerate cases where // the incoming IR has not yet been fully simplified. if (BackedgeTakenCount->isZero()) { Value *InitValue = nullptr; bool MultipleInitValues = false; for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) { - if (!CurrLoop->contains(PN->getIncomingBlock(i))) { + if (!CurrLoop->contains(PN->getIncomingBlock(i))) { if (!InitValue) InitValue = PN->getIncomingValue(i); else if (InitValue != PN->getIncomingValue(i)) { @@ -8504,18 +8504,18 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { isKnownPositive(BackedgeTakenCount) && PN->getNumIncomingValues() == 2) { - unsigned InLoopPred = - CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1; + unsigned InLoopPred = + CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1; Value *BackedgeVal = PN->getIncomingValue(InLoopPred); - if (CurrLoop->isLoopInvariant(BackedgeVal)) + if (CurrLoop->isLoopInvariant(BackedgeVal)) return getSCEV(BackedgeVal); } if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) { // Okay, we know how many times the containing loop executes. If // this is a constant evolving PHI node, get the final value at // the specified iteration number. 
- Constant *RV = getConstantEvolutionLoopExitValue( - PN, BTCC->getAPInt(), CurrLoop); + Constant *RV = getConstantEvolutionLoopExitValue( + PN, BTCC->getAPInt(), CurrLoop); if (RV) return getSCEV(RV); } } @@ -8571,10 +8571,10 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { if (const CmpInst *CI = dyn_cast<CmpInst>(I)) C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0], Operands[1], DL, &TLI); - else if (const LoadInst *Load = dyn_cast<LoadInst>(I)) { - if (!Load->isVolatile()) - C = ConstantFoldLoadFromConstPtr(Operands[0], Load->getType(), - DL); + else if (const LoadInst *Load = dyn_cast<LoadInst>(I)) { + if (!Load->isVolatile()) + C = ConstantFoldLoadFromConstPtr(Operands[0], Load->getType(), + DL); } else C = ConstantFoldInstOperands(I, Operands, DL, &TLI); if (!C) return V; @@ -8691,13 +8691,13 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { return getTruncateExpr(Op, Cast->getType()); } - if (const SCEVPtrToIntExpr *Cast = dyn_cast<SCEVPtrToIntExpr>(V)) { - const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); - if (Op == Cast->getOperand()) - return Cast; // must be loop invariant - return getPtrToIntExpr(Op, Cast->getType()); - } - + if (const SCEVPtrToIntExpr *Cast = dyn_cast<SCEVPtrToIntExpr>(V)) { + const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L); + if (Op == Cast->getOperand()) + return Cast; // must be loop invariant + return getPtrToIntExpr(Op, Cast->getType()); + } + llvm_unreachable("Unknown SCEV type!"); } @@ -9112,10 +9112,10 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, // 1*N = -Start; -1*N = Start (mod 2^BW), so: // N = Distance (as unsigned) if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) { - APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L)); - APInt MaxBECountBase = getUnsignedRangeMax(Distance); - if (MaxBECountBase.ult(MaxBECount)) - MaxBECount = MaxBECountBase; + APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L)); + APInt MaxBECountBase = getUnsignedRangeMax(Distance); + if (MaxBECountBase.ult(MaxBECount)) + MaxBECount = MaxBECountBase; // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated, // we end up with a loop whose backedge-taken count is n - 1. Detect this @@ -9180,19 +9180,19 @@ ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) { return getCouldNotCompute(); } -std::pair<const BasicBlock *, const BasicBlock *> -ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB) - const { +std::pair<const BasicBlock *, const BasicBlock *> +ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB) + const { // If the block has a unique predecessor, then there is no path from the // predecessor to the block that does not go through the direct edge // from the predecessor to the block. - if (const BasicBlock *Pred = BB->getSinglePredecessor()) + if (const BasicBlock *Pred = BB->getSinglePredecessor()) return {Pred, BB}; // A loop's header is defined to be a block that dominates the loop. // If the header has a unique predecessor outside the loop, it must be // a block that has exactly one successor that can reach the loop. 
- if (const Loop *L = LI.getLoopFor(BB)) + if (const Loop *L = LI.getLoopFor(BB)) return {L->getLoopPredecessor(), L->getHeader()}; return {nullptr, nullptr}; @@ -9521,14 +9521,14 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS); } -bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const Instruction *Context) { - // TODO: Analyze guards and assumes from Context's block. - return isKnownPredicate(Pred, LHS, RHS) || - isBasicBlockEntryGuardedByCond(Context->getParent(), Pred, LHS, RHS); -} - +bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const Instruction *Context) { + // TODO: Analyze guards and assumes from Context's block. + return isKnownPredicate(Pred, LHS, RHS) || + isBasicBlockEntryGuardedByCond(Context->getParent(), Pred, LHS, RHS); +} + bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred, const SCEVAddRecExpr *LHS, const SCEV *RHS) { @@ -9537,30 +9537,30 @@ bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred, isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS); } -Optional<ScalarEvolution::MonotonicPredicateType> -ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS, - ICmpInst::Predicate Pred) { - auto Result = getMonotonicPredicateTypeImpl(LHS, Pred); +Optional<ScalarEvolution::MonotonicPredicateType> +ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS, + ICmpInst::Predicate Pred) { + auto Result = getMonotonicPredicateTypeImpl(LHS, Pred); #ifndef NDEBUG // Verify an invariant: inverting the predicate should turn a monotonically // increasing change to a monotonically decreasing one, and vice versa. - if (Result) { - auto ResultSwapped = - getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred)); + if (Result) { + auto ResultSwapped = + getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred)); - assert(ResultSwapped.hasValue() && "should be able to analyze both!"); - assert(ResultSwapped.getValue() != Result.getValue() && + assert(ResultSwapped.hasValue() && "should be able to analyze both!"); + assert(ResultSwapped.getValue() != Result.getValue() && "monotonicity should flip as we flip the predicate"); - } + } #endif return Result; } -Optional<ScalarEvolution::MonotonicPredicateType> -ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS, - ICmpInst::Predicate Pred) { +Optional<ScalarEvolution::MonotonicPredicateType> +ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS, + ICmpInst::Predicate Pred) { // A zero step value for LHS means the induction variable is essentially a // loop invariant value. We don't really depend on the predicate actually // flipping from false to true (for increasing predicates, and the other way @@ -9571,46 +9571,46 @@ ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS, // where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be // as general as possible. - // Only handle LE/LT/GE/GT predicates. - if (!ICmpInst::isRelational(Pred)) - return None; - - bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred); - assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) && - "Should be greater or less!"); + // Only handle LE/LT/GE/GT predicates. + if (!ICmpInst::isRelational(Pred)) + return None; - // Check that AR does not wrap. 
- if (ICmpInst::isUnsigned(Pred)) { + bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred); + assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) && + "Should be greater or less!"); + + // Check that AR does not wrap. + if (ICmpInst::isUnsigned(Pred)) { if (!LHS->hasNoUnsignedWrap()) - return None; - return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing; - } else { - assert(ICmpInst::isSigned(Pred) && - "Relational predicate is either signed or unsigned!"); + return None; + return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing; + } else { + assert(ICmpInst::isSigned(Pred) && + "Relational predicate is either signed or unsigned!"); if (!LHS->hasNoSignedWrap()) - return None; + return None; const SCEV *Step = LHS->getStepRecurrence(*this); - if (isKnownNonNegative(Step)) - return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing; + if (isKnownNonNegative(Step)) + return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing; - if (isKnownNonPositive(Step)) - return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing; + if (isKnownNonPositive(Step)) + return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing; - return None; + return None; } } -Optional<ScalarEvolution::LoopInvariantPredicate> -ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS, - const Loop *L) { +Optional<ScalarEvolution::LoopInvariantPredicate> +ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred, + const SCEV *LHS, const SCEV *RHS, + const Loop *L) { // If there is a loop-invariant, force it into the RHS, otherwise bail out. if (!isLoopInvariant(RHS, L)) { if (!isLoopInvariant(LHS, L)) - return None; + return None; std::swap(LHS, RHS); Pred = ICmpInst::getSwappedPredicate(Pred); @@ -9618,11 +9618,11 @@ ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred, const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS); if (!ArLHS || ArLHS->getLoop() != L) - return None; + return None; - auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred); - if (!MonotonicType) - return None; + auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred); + if (!MonotonicType) + return None; // If the predicate "ArLHS `Pred` RHS" monotonically increases from false to // true as the loop iterates, and the backedge is control dependent on // "ArLHS `Pred` RHS" == true then we can reason as follows: @@ -9640,79 +9640,79 @@ ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred, // // A similar reasoning applies for a monotonically decreasing predicate, by // replacing true with false and false with true in the above two bullets. - bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing; + bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing; auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred); if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS)) - return None; - - return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(), RHS); -} - -Optional<ScalarEvolution::LoopInvariantPredicate> -ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations( - ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, - const Instruction *Context, const SCEV *MaxIter) { - // Try to prove the following set of facts: - // - The predicate is monotonic in the iteration space. 
- // - If the check does not fail on the 1st iteration: - // - No overflow will happen during first MaxIter iterations; - // - It will not fail on the MaxIter'th iteration. - // If the check does fail on the 1st iteration, we leave the loop and no - // other checks matter. - - // If there is a loop-invariant, force it into the RHS, otherwise bail out. - if (!isLoopInvariant(RHS, L)) { - if (!isLoopInvariant(LHS, L)) - return None; - - std::swap(LHS, RHS); - Pred = ICmpInst::getSwappedPredicate(Pred); - } - - auto *AR = dyn_cast<SCEVAddRecExpr>(LHS); - if (!AR || AR->getLoop() != L) - return None; - - // The predicate must be relational (i.e. <, <=, >=, >). - if (!ICmpInst::isRelational(Pred)) - return None; - - // TODO: Support steps other than +/- 1. - const SCEV *Step = AR->getStepRecurrence(*this); - auto *One = getOne(Step->getType()); - auto *MinusOne = getNegativeSCEV(One); - if (Step != One && Step != MinusOne) - return None; - - // Type mismatch here means that MaxIter is potentially larger than max - // unsigned value in start type, which mean we cannot prove no wrap for the - // indvar. - if (AR->getType() != MaxIter->getType()) - return None; - - // Value of IV on suggested last iteration. - const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this); - // Does it still meet the requirement? - if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS)) - return None; - // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does - // not exceed max unsigned value of this type), this effectively proves - // that there is no wrap during the iteration. To prove that there is no - // signed/unsigned wrap, we need to check that - // Start <= Last for step = 1 or Start >= Last for step = -1. - ICmpInst::Predicate NoOverflowPred = - CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; - if (Step == MinusOne) - NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred); - const SCEV *Start = AR->getStart(); - if (!isKnownPredicateAt(NoOverflowPred, Start, Last, Context)) - return None; - - // Everything is fine. - return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS); -} - + return None; + + return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(), RHS); +} + +Optional<ScalarEvolution::LoopInvariantPredicate> +ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations( + ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L, + const Instruction *Context, const SCEV *MaxIter) { + // Try to prove the following set of facts: + // - The predicate is monotonic in the iteration space. + // - If the check does not fail on the 1st iteration: + // - No overflow will happen during first MaxIter iterations; + // - It will not fail on the MaxIter'th iteration. + // If the check does fail on the 1st iteration, we leave the loop and no + // other checks matter. + + // If there is a loop-invariant, force it into the RHS, otherwise bail out. + if (!isLoopInvariant(RHS, L)) { + if (!isLoopInvariant(LHS, L)) + return None; + + std::swap(LHS, RHS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + auto *AR = dyn_cast<SCEVAddRecExpr>(LHS); + if (!AR || AR->getLoop() != L) + return None; + + // The predicate must be relational (i.e. <, <=, >=, >). + if (!ICmpInst::isRelational(Pred)) + return None; + + // TODO: Support steps other than +/- 1. 
+ const SCEV *Step = AR->getStepRecurrence(*this); + auto *One = getOne(Step->getType()); + auto *MinusOne = getNegativeSCEV(One); + if (Step != One && Step != MinusOne) + return None; + + // Type mismatch here means that MaxIter is potentially larger than max + // unsigned value in start type, which mean we cannot prove no wrap for the + // indvar. + if (AR->getType() != MaxIter->getType()) + return None; + + // Value of IV on suggested last iteration. + const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this); + // Does it still meet the requirement? + if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS)) + return None; + // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does + // not exceed max unsigned value of this type), this effectively proves + // that there is no wrap during the iteration. To prove that there is no + // signed/unsigned wrap, we need to check that + // Start <= Last for step = 1 or Start >= Last for step = -1. + ICmpInst::Predicate NoOverflowPred = + CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; + if (Step == MinusOne) + NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred); + const SCEV *Start = AR->getStart(); + if (!isKnownPredicateAt(NoOverflowPred, Start, Last, Context)) + return None; + + // Everything is fine. + return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS); +} + bool ScalarEvolution::isKnownPredicateViaConstantRanges( ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { if (HasSameValue(LHS, RHS)) @@ -9795,24 +9795,24 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative()) return true; break; - - case ICmpInst::ICMP_UGE: - std::swap(LHS, RHS); - LLVM_FALLTHROUGH; - case ICmpInst::ICMP_ULE: - // X u<= (X + C)<nuw> for any C - if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW)) - return true; - break; - - case ICmpInst::ICMP_UGT: - std::swap(LHS, RHS); - LLVM_FALLTHROUGH; - case ICmpInst::ICMP_ULT: - // X u< (X + C)<nuw> if C != 0 - if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW) && !C.isNullValue()) - return true; - break; + + case ICmpInst::ICMP_UGE: + std::swap(LHS, RHS); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_ULE: + // X u<= (X + C)<nuw> for any C + if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW)) + return true; + break; + + case ICmpInst::ICMP_UGT: + std::swap(LHS, RHS); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_ULT: + // X u< (X + C)<nuw> if C != 0 + if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW) && !C.isNullValue()) + return true; + break; } return false; @@ -9840,14 +9840,14 @@ bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred, isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS); } -bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, +bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { // No need to even try if we know the module has no guards. 
if (!HasGuards) return false; - return any_of(*BB, [&](const Instruction &I) { + return any_of(*BB, [&](const Instruction &I) { using namespace llvm::PatternMatch; Value *Condition; @@ -9970,12 +9970,12 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, return false; } -bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB, - ICmpInst::Predicate Pred, - const SCEV *LHS, - const SCEV *RHS) { +bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB, + ICmpInst::Predicate Pred, + const SCEV *LHS, + const SCEV *RHS) { if (VerifyIR) - assert(!verifyFunction(*BB->getParent(), &dbgs()) && + assert(!verifyFunction(*BB->getParent(), &dbgs()) && "This cannot be done on broken IR!"); if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS)) @@ -10001,7 +10001,7 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB, } // Try to prove (Pred, LHS, RHS) using isImpliedViaGuard. - auto ProveViaGuard = [&](const BasicBlock *Block) { + auto ProveViaGuard = [&](const BasicBlock *Block) { if (isImpliedViaGuard(Block, Pred, LHS, RHS)) return true; if (ProvingStrictComparison) { @@ -10018,39 +10018,39 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB, }; // Try to prove (Pred, LHS, RHS) using isImpliedCond. - auto ProveViaCond = [&](const Value *Condition, bool Inverse) { - const Instruction *Context = &BB->front(); - if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, Context)) + auto ProveViaCond = [&](const Value *Condition, bool Inverse) { + const Instruction *Context = &BB->front(); + if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, Context)) return true; if (ProvingStrictComparison) { if (!ProvedNonStrictComparison) - ProvedNonStrictComparison = isImpliedCond(NonStrictPredicate, LHS, RHS, - Condition, Inverse, Context); + ProvedNonStrictComparison = isImpliedCond(NonStrictPredicate, LHS, RHS, + Condition, Inverse, Context); if (!ProvedNonEquality) - ProvedNonEquality = isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, - Condition, Inverse, Context); + ProvedNonEquality = isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, + Condition, Inverse, Context); if (ProvedNonStrictComparison && ProvedNonEquality) return true; } return false; }; - // Starting at the block's predecessor, climb up the predecessor chain, as long + // Starting at the block's predecessor, climb up the predecessor chain, as long // as there are predecessors that can be found that have unique successors - // leading to the original block. - const Loop *ContainingLoop = LI.getLoopFor(BB); - const BasicBlock *PredBB; - if (ContainingLoop && ContainingLoop->getHeader() == BB) - PredBB = ContainingLoop->getLoopPredecessor(); - else - PredBB = BB->getSinglePredecessor(); - for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB); - Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) { + // leading to the original block. 
+ const Loop *ContainingLoop = LI.getLoopFor(BB); + const BasicBlock *PredBB; + if (ContainingLoop && ContainingLoop->getHeader() == BB) + PredBB = ContainingLoop->getLoopPredecessor(); + else + PredBB = BB->getSinglePredecessor(); + for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB); + Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) { if (ProveViaGuard(Pair.first)) return true; - const BranchInst *LoopEntryPredicate = - dyn_cast<BranchInst>(Pair.first->getTerminator()); + const BranchInst *LoopEntryPredicate = + dyn_cast<BranchInst>(Pair.first->getTerminator()); if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional()) continue; @@ -10065,7 +10065,7 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB, if (!AssumeVH) continue; auto *CI = cast<CallInst>(AssumeVH); - if (!DT.dominates(CI, BB)) + if (!DT.dominates(CI, BB)) continue; if (ProveViaCond(CI->getArgOperand(0), false)) @@ -10075,27 +10075,27 @@ bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB, return false; } -bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, - ICmpInst::Predicate Pred, - const SCEV *LHS, - const SCEV *RHS) { - // Interpret a null as meaning no loop, where there is obviously no guard - // (interprocedural conditions notwithstanding). - if (!L) - return false; - - // Both LHS and RHS must be available at loop entry. - assert(isAvailableAtLoopEntry(LHS, L) && - "LHS is not available at Loop Entry"); - assert(isAvailableAtLoopEntry(RHS, L) && - "RHS is not available at Loop Entry"); - return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS); -} - -bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, - const SCEV *RHS, - const Value *FoundCondValue, bool Inverse, - const Instruction *Context) { +bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, + ICmpInst::Predicate Pred, + const SCEV *LHS, + const SCEV *RHS) { + // Interpret a null as meaning no loop, where there is obviously no guard + // (interprocedural conditions notwithstanding). + if (!L) + return false; + + // Both LHS and RHS must be available at loop entry. + assert(isAvailableAtLoopEntry(LHS, L) && + "LHS is not available at Loop Entry"); + assert(isAvailableAtLoopEntry(RHS, L) && + "RHS is not available at Loop Entry"); + return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS); +} + +bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, + const SCEV *RHS, + const Value *FoundCondValue, bool Inverse, + const Instruction *Context) { if (!PendingLoopPredicates.insert(FoundCondValue).second) return false; @@ -10103,23 +10103,23 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); }); // Recursively handle And and Or conditions. 
- if (const BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) { + if (const BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) { if (BO->getOpcode() == Instruction::And) { if (!Inverse) - return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse, - Context) || - isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse, - Context); + return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse, + Context) || + isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse, + Context); } else if (BO->getOpcode() == Instruction::Or) { if (Inverse) - return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse, - Context) || - isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse, - Context); + return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse, + Context) || + isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse, + Context); } } - const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue); + const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue); if (!ICI) return false; // Now that we found a conditional branch that dominates the loop or controls @@ -10133,36 +10133,36 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *FoundLHS = getSCEV(ICI->getOperand(0)); const SCEV *FoundRHS = getSCEV(ICI->getOperand(1)); - return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, Context); + return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, Context); } bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, ICmpInst::Predicate FoundPred, - const SCEV *FoundLHS, const SCEV *FoundRHS, - const Instruction *Context) { + const SCEV *FoundLHS, const SCEV *FoundRHS, + const Instruction *Context) { // Balance the types. if (getTypeSizeInBits(LHS->getType()) < getTypeSizeInBits(FoundLHS->getType())) { - // For unsigned and equality predicates, try to prove that both found - // operands fit into narrow unsigned range. If so, try to prove facts in - // narrow types. - if (!CmpInst::isSigned(FoundPred)) { - auto *NarrowType = LHS->getType(); - auto *WideType = FoundLHS->getType(); - auto BitWidth = getTypeSizeInBits(NarrowType); - const SCEV *MaxValue = getZeroExtendExpr( - getConstant(APInt::getMaxValue(BitWidth)), WideType); - if (isKnownPredicate(ICmpInst::ICMP_ULE, FoundLHS, MaxValue) && - isKnownPredicate(ICmpInst::ICMP_ULE, FoundRHS, MaxValue)) { - const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType); - const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType); - if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS, - TruncFoundRHS, Context)) - return true; - } - } - + // For unsigned and equality predicates, try to prove that both found + // operands fit into narrow unsigned range. If so, try to prove facts in + // narrow types. 
+ if (!CmpInst::isSigned(FoundPred)) { + auto *NarrowType = LHS->getType(); + auto *WideType = FoundLHS->getType(); + auto BitWidth = getTypeSizeInBits(NarrowType); + const SCEV *MaxValue = getZeroExtendExpr( + getConstant(APInt::getMaxValue(BitWidth)), WideType); + if (isKnownPredicate(ICmpInst::ICMP_ULE, FoundLHS, MaxValue) && + isKnownPredicate(ICmpInst::ICMP_ULE, FoundRHS, MaxValue)) { + const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType); + const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType); + if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS, + TruncFoundRHS, Context)) + return true; + } + } + if (CmpInst::isSigned(Pred)) { LHS = getSignExtendExpr(LHS, FoundLHS->getType()); RHS = getSignExtendExpr(RHS, FoundLHS->getType()); @@ -10180,17 +10180,17 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType()); } } - return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS, - FoundRHS, Context); -} + return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS, + FoundRHS, Context); +} -bool ScalarEvolution::isImpliedCondBalancedTypes( - ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS, - const Instruction *Context) { - assert(getTypeSizeInBits(LHS->getType()) == - getTypeSizeInBits(FoundLHS->getType()) && - "Types should be balanced!"); +bool ScalarEvolution::isImpliedCondBalancedTypes( + ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS, + const Instruction *Context) { + assert(getTypeSizeInBits(LHS->getType()) == + getTypeSizeInBits(FoundLHS->getType()) && + "Types should be balanced!"); // Canonicalize the query to match the way instcombine will have // canonicalized the comparison. if (SimplifyICmpOperands(Pred, LHS, RHS)) @@ -10213,16 +10213,16 @@ bool ScalarEvolution::isImpliedCondBalancedTypes( // Check whether the found predicate is the same as the desired predicate. if (FoundPred == Pred) - return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context); + return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context); // Check whether swapping the found predicate makes it the same as the // desired predicate. if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) { if (isa<SCEVConstant>(RHS)) - return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, Context); + return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, Context); else - return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), RHS, - LHS, FoundLHS, FoundRHS, Context); + return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), RHS, + LHS, FoundLHS, FoundRHS, Context); } // Unsigned comparison is the same as signed comparison when both the operands @@ -10230,7 +10230,7 @@ bool ScalarEvolution::isImpliedCondBalancedTypes( if (CmpInst::isUnsigned(FoundPred) && CmpInst::getSignedPredicate(FoundPred) == Pred && isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) - return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context); + return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context); // Check if we can make progress by sharpening ranges. if (FoundPred == ICmpInst::ICMP_NE && @@ -10267,8 +10267,8 @@ bool ScalarEvolution::isImpliedCondBalancedTypes( case ICmpInst::ICMP_UGE: // We know V `Pred` SharperMin. 
If this implies LHS `Pred` // RHS, we're done. - if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin), - Context)) + if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin), + Context)) return true; LLVM_FALLTHROUGH; @@ -10283,26 +10283,26 @@ bool ScalarEvolution::isImpliedCondBalancedTypes( // // If V `Pred` Min implies LHS `Pred` RHS, we're done. - if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), - Context)) - return true; - break; - - // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively. - case ICmpInst::ICMP_SLE: - case ICmpInst::ICMP_ULE: - if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS, - LHS, V, getConstant(SharperMin), Context)) + if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min), + Context)) return true; + break; + + // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively. + case ICmpInst::ICMP_SLE: + case ICmpInst::ICMP_ULE: + if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS, + LHS, V, getConstant(SharperMin), Context)) + return true; LLVM_FALLTHROUGH; - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_ULT: - if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS, - LHS, V, getConstant(Min), Context)) - return true; - break; - + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_ULT: + if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS, + LHS, V, getConstant(Min), Context)) + return true; + break; + default: // No change break; @@ -10313,12 +10313,12 @@ bool ScalarEvolution::isImpliedCondBalancedTypes( // Check whether the actual condition is beyond sufficient. if (FoundPred == ICmpInst::ICMP_EQ) if (ICmpInst::isTrueWhenEqual(Pred)) - if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context)) + if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context)) return true; if (Pred == ICmpInst::ICMP_NE) if (!ICmpInst::isTrueWhenEqual(FoundPred)) - if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, - Context)) + if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS, + Context)) return true; // Otherwise assume the worst. @@ -10397,51 +10397,51 @@ Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More, return None; } -bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart( - ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, - const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *Context) { - // Try to recognize the following pattern: - // - // FoundRHS = ... - // ... - // loop: - // FoundLHS = {Start,+,W} - // context_bb: // Basic block from the same loop - // known(Pred, FoundLHS, FoundRHS) - // - // If some predicate is known in the context of a loop, it is also known on - // each iteration of this loop, including the first iteration. Therefore, in - // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to - // prove the original pred using this fact. - if (!Context) - return false; - const BasicBlock *ContextBB = Context->getParent(); - // Make sure AR varies in the context block. - if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) { - const Loop *L = AR->getLoop(); - // Make sure that context belongs to the loop and executes on 1st iteration - // (if it ever executes at all). 
- if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch())) - return false; - if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop())) - return false; - return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS); - } - - if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) { - const Loop *L = AR->getLoop(); - // Make sure that context belongs to the loop and executes on 1st iteration - // (if it ever executes at all). - if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch())) - return false; - if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop())) - return false; - return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart()); - } - - return false; -} - +bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart( + ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *Context) { + // Try to recognize the following pattern: + // + // FoundRHS = ... + // ... + // loop: + // FoundLHS = {Start,+,W} + // context_bb: // Basic block from the same loop + // known(Pred, FoundLHS, FoundRHS) + // + // If some predicate is known in the context of a loop, it is also known on + // each iteration of this loop, including the first iteration. Therefore, in + // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to + // prove the original pred using this fact. + if (!Context) + return false; + const BasicBlock *ContextBB = Context->getParent(); + // Make sure AR varies in the context block. + if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) { + const Loop *L = AR->getLoop(); + // Make sure that context belongs to the loop and executes on 1st iteration + // (if it ever executes at all). + if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch())) + return false; + if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop())) + return false; + return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS); + } + + if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) { + const Loop *L = AR->getLoop(); + // Make sure that context belongs to the loop and executes on 1st iteration + // (if it ever executes at all). + if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch())) + return false; + if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop())) + return false; + return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart()); + } + + return false; +} + bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { @@ -10622,10 +10622,10 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred, if (!dominates(RHS, IncBB)) return false; const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB)); - // Make sure L does not refer to a value from a potentially previous - // iteration of a loop. - if (!properlyDominates(L, IncBB)) - return false; + // Make sure L does not refer to a value from a potentially previous + // iteration of a loop. 
+ if (!properlyDominates(L, IncBB)) + return false; if (!ProvedEasily(L, RHS)) return false; } @@ -10636,18 +10636,18 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred, bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, - const SCEV *FoundRHS, - const Instruction *Context) { + const SCEV *FoundRHS, + const Instruction *Context) { if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS)) return true; if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS)) return true; - if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS, - Context)) - return true; - + if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS, + Context)) + return true; + return isImpliedCondOperandsHelper(Pred, LHS, RHS, FoundLHS, FoundRHS) || // ~x < ~y --> x > y @@ -10664,7 +10664,7 @@ static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr, if (!MinMaxExpr) return false; - return is_contained(MinMaxExpr->operands(), Candidate); + return is_contained(MinMaxExpr->operands(), Candidate); } static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE, @@ -10746,31 +10746,31 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred, // We want to avoid hurting the compile time with analysis of too big trees. if (Depth > MaxSCEVOperationsImplicationDepth) return false; - - // We only want to work with GT comparison so far. - if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) { - Pred = CmpInst::getSwappedPredicate(Pred); + + // We only want to work with GT comparison so far. + if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) { + Pred = CmpInst::getSwappedPredicate(Pred); std::swap(LHS, RHS); std::swap(FoundLHS, FoundRHS); } - - // For unsigned, try to reduce it to corresponding signed comparison. - if (Pred == ICmpInst::ICMP_UGT) - // We can replace unsigned predicate with its signed counterpart if all - // involved values are non-negative. - // TODO: We could have better support for unsigned. - if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) { - // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing - // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us - // use this fact to prove that LHS and RHS are non-negative. - const SCEV *MinusOne = getMinusOne(LHS->getType()); - if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS, - FoundRHS) && - isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS, - FoundRHS)) - Pred = ICmpInst::ICMP_SGT; - } - + + // For unsigned, try to reduce it to corresponding signed comparison. + if (Pred == ICmpInst::ICMP_UGT) + // We can replace unsigned predicate with its signed counterpart if all + // involved values are non-negative. + // TODO: We could have better support for unsigned. + if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) { + // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing + // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us + // use this fact to prove that LHS and RHS are non-negative. 
+ const SCEV *MinusOne = getMinusOne(LHS->getType()); + if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS, + FoundRHS) && + isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS, + FoundRHS)) + Pred = ICmpInst::ICMP_SGT; + } + if (Pred != ICmpInst::ICMP_SGT) return false; @@ -10810,7 +10810,7 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred, auto *LL = LHSAddExpr->getOperand(0); auto *LR = LHSAddExpr->getOperand(1); - auto *MinusOne = getMinusOne(RHS->getType()); + auto *MinusOne = getMinusOne(RHS->getType()); // Checks that S1 >= 0 && S2 > RHS, trivially or using the found context. auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) { @@ -10883,7 +10883,7 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred, // 1. If FoundLHS is negative, then the result is 0. // 2. If FoundLHS is non-negative, then the result is non-negative. // Anyways, the result is non-negative. - auto *MinusOne = getMinusOne(WTy); + auto *MinusOne = getMinusOne(WTy); auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt); if (isKnownNegative(RHS) && IsSGTViaContext(FoundRHSExt, NegDenomMinusOne)) @@ -11238,13 +11238,13 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) BECount = BECountIfBackedgeTaken; else { - // If we know that RHS >= Start in the context of loop, then we know that - // max(RHS, Start) = RHS at this point. - if (isLoopEntryGuardedByCond( - L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, RHS, Start)) - End = RHS; - else - End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); + // If we know that RHS >= Start in the context of loop, then we know that + // max(RHS, Start) = RHS at this point. + if (isLoopEntryGuardedByCond( + L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, RHS, Start)) + End = RHS; + else + End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); } @@ -11311,15 +11311,15 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const SCEV *Start = IV->getStart(); const SCEV *End = RHS; - if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) { - // If we know that Start >= RHS in the context of loop, then we know that - // min(RHS, Start) = RHS at this point. - if (isLoopEntryGuardedByCond( - L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS)) - End = RHS; - else - End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start); - } + if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) { + // If we know that Start >= RHS in the context of loop, then we know that + // min(RHS, Start) = RHS at this point. + if (isLoopEntryGuardedByCond( + L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS)) + End = RHS; + else + End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start); + } const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false); @@ -11359,7 +11359,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, // If the start is a non-zero constant, shift the range to simplify things. 
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart())) if (!SC->getValue()->isZero()) { - SmallVector<const SCEV *, 4> Operands(operands()); + SmallVector<const SCEV *, 4> Operands(operands()); Operands[0] = SE.getZero(SC->getType()); const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(), getNoWrapFlags(FlagNW)); @@ -11642,7 +11642,7 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE, } // Remove all SCEVConstants. - erase_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); }); + erase_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); }); if (Terms.size() > 0) if (!findArrayDimensionsRec(SE, Terms, Sizes)) @@ -11970,7 +11970,7 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) { // so that future queries will recompute the expressions using the new // value. Value *Old = getValPtr(); - SmallVector<User *, 16> Worklist(Old->users()); + SmallVector<User *, 16> Worklist(Old->users()); SmallPtrSet<User *, 8> Visited; while (!Worklist.empty()) { User *U = Worklist.pop_back_val(); @@ -11983,7 +11983,7 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) { if (PHINode *PN = dyn_cast<PHINode>(U)) SE->ConstantEvolutionLoopExitValue.erase(PN); SE->eraseValueFromMap(U); - llvm::append_range(Worklist, U->users()); + llvm::append_range(Worklist, U->users()); } // Delete the Old value. if (PHINode *PN = dyn_cast<PHINode>(Old)) @@ -12265,10 +12265,10 @@ ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) { ScalarEvolution::LoopDisposition ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) { - switch (S->getSCEVType()) { + switch (S->getSCEVType()) { case scConstant: return LoopInvariant; - case scPtrToInt: + case scPtrToInt: case scTruncate: case scZeroExtend: case scSignExtend: @@ -12373,10 +12373,10 @@ ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) { ScalarEvolution::BlockDisposition ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) { - switch (S->getSCEVType()) { + switch (S->getSCEVType()) { case scConstant: return ProperlyDominatesBlock; - case scPtrToInt: + case scPtrToInt: case scTruncate: case scZeroExtend: case scSignExtend: @@ -12548,7 +12548,7 @@ void ScalarEvolution::verify() const { while (!LoopStack.empty()) { auto *L = LoopStack.pop_back_val(); - llvm::append_range(LoopStack, *L); + llvm::append_range(LoopStack, *L); auto *CurBECount = SCM.visit( const_cast<ScalarEvolution *>(this)->getBackedgeTakenCount(L)); @@ -12592,25 +12592,25 @@ void ScalarEvolution::verify() const { std::abort(); } } - - // Collect all valid loops currently in LoopInfo. - SmallPtrSet<Loop *, 32> ValidLoops; - SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end()); - while (!Worklist.empty()) { - Loop *L = Worklist.pop_back_val(); - if (ValidLoops.contains(L)) - continue; - ValidLoops.insert(L); - Worklist.append(L->begin(), L->end()); - } - // Check for SCEV expressions referencing invalid/deleted loops. - for (auto &KV : ValueExprMap) { - auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second); - if (!AR) - continue; - assert(ValidLoops.contains(AR->getLoop()) && - "AddRec references invalid loop"); - } + + // Collect all valid loops currently in LoopInfo. 
+ SmallPtrSet<Loop *, 32> ValidLoops; + SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end()); + while (!Worklist.empty()) { + Loop *L = Worklist.pop_back_val(); + if (ValidLoops.contains(L)) + continue; + ValidLoops.insert(L); + Worklist.append(L->begin(), L->end()); + } + // Check for SCEV expressions referencing invalid/deleted loops. + for (auto &KV : ValueExprMap) { + auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second); + if (!AR) + continue; + assert(ValidLoops.contains(AR->getLoop()) && + "AddRec references invalid loop"); + } } bool ScalarEvolution::invalidate( @@ -12643,11 +12643,11 @@ ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) { PreservedAnalyses ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { - // For compatibility with opt's -analyze feature under legacy pass manager - // which was not ported to NPM. This keeps tests using - // update_analyze_test_checks.py working. - OS << "Printing analysis 'Scalar Evolution Analysis' for function '" - << F.getName() << "':\n"; + // For compatibility with opt's -analyze feature under legacy pass manager + // which was not ported to NPM. This keeps tests using + // update_analyze_test_checks.py working. + OS << "Printing analysis 'Scalar Evolution Analysis' for function '" + << F.getName() << "':\n"; AM.getResult<ScalarEvolutionAnalysis>(F).print(OS); return PreservedAnalyses::all(); } @@ -13143,24 +13143,24 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const { } // Match the mathematical pattern A - (A / B) * B, where A and B can be -// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used -// for URem with constant power-of-2 second operands. +// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used +// for URem with constant power-of-2 second operands. // It's not always easy, as A and B can be folded (imagine A is X / 2, and B is // 4, A / B becomes X / 8). bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS, const SCEV *&RHS) { - // Try to match 'zext (trunc A to iB) to iY', which is used - // for URem with constant power-of-2 second operands. Make sure the size of - // the operand A matches the size of the whole expressions. - if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr)) - if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) { - LHS = Trunc->getOperand(); - if (LHS->getType() != Expr->getType()) - LHS = getZeroExtendExpr(LHS, Expr->getType()); - RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1) - << getTypeSizeInBits(Trunc->getType())); - return true; - } + // Try to match 'zext (trunc A to iB) to iY', which is used + // for URem with constant power-of-2 second operands. Make sure the size of + // the operand A matches the size of the whole expressions. 
+ if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr)) + if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) { + LHS = Trunc->getOperand(); + if (LHS->getType() != Expr->getType()) + LHS = getZeroExtendExpr(LHS, Expr->getType()); + RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1) + << getTypeSizeInBits(Trunc->getType())); + return true; + } const auto *Add = dyn_cast<SCEVAddExpr>(Expr); if (Add == nullptr || Add->getNumOperands() != 2) return false; @@ -13194,146 +13194,146 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS, MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0))); return false; } - -const SCEV * -ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) { - SmallVector<BasicBlock*, 16> ExitingBlocks; - L->getExitingBlocks(ExitingBlocks); - - // Form an expression for the maximum exit count possible for this loop. We - // merge the max and exact information to approximate a version of - // getConstantMaxBackedgeTakenCount which isn't restricted to just constants. - SmallVector<const SCEV*, 4> ExitCounts; - for (BasicBlock *ExitingBB : ExitingBlocks) { - const SCEV *ExitCount = getExitCount(L, ExitingBB); - if (isa<SCEVCouldNotCompute>(ExitCount)) - ExitCount = getExitCount(L, ExitingBB, - ScalarEvolution::ConstantMaximum); - if (!isa<SCEVCouldNotCompute>(ExitCount)) { - assert(DT.dominates(ExitingBB, L->getLoopLatch()) && - "We should only have known counts for exiting blocks that " - "dominate latch!"); - ExitCounts.push_back(ExitCount); - } - } - if (ExitCounts.empty()) - return getCouldNotCompute(); - return getUMinFromMismatchedTypes(ExitCounts); -} - -/// This rewriter is similar to SCEVParameterRewriter (it replaces SCEVUnknown -/// components following the Map (Value -> SCEV)), but skips AddRecExpr because -/// we cannot guarantee that the replacement is loop invariant in the loop of -/// the AddRec. -class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> { - ValueToSCEVMapTy ⤅ - -public: - SCEVLoopGuardRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M) - : SCEVRewriteVisitor(SE), Map(M) {} - - const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; } - - const SCEV *visitUnknown(const SCEVUnknown *Expr) { - auto I = Map.find(Expr->getValue()); - if (I == Map.end()) - return Expr; - return I->second; - } -}; - -const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { - auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS, - const SCEV *RHS, ValueToSCEVMapTy &RewriteMap) { - if (!isa<SCEVUnknown>(LHS)) { - std::swap(LHS, RHS); - Predicate = CmpInst::getSwappedPredicate(Predicate); - } - - // For now, limit to conditions that provide information about unknown - // expressions. - auto *LHSUnknown = dyn_cast<SCEVUnknown>(LHS); - if (!LHSUnknown) - return; - - // TODO: use information from more predicates. 
- switch (Predicate) { - case CmpInst::ICMP_ULT: { - if (!containsAddRecurrence(RHS)) { - const SCEV *Base = LHS; - auto I = RewriteMap.find(LHSUnknown->getValue()); - if (I != RewriteMap.end()) - Base = I->second; - - RewriteMap[LHSUnknown->getValue()] = - getUMinExpr(Base, getMinusSCEV(RHS, getOne(RHS->getType()))); - } - break; - } - case CmpInst::ICMP_ULE: { - if (!containsAddRecurrence(RHS)) { - const SCEV *Base = LHS; - auto I = RewriteMap.find(LHSUnknown->getValue()); - if (I != RewriteMap.end()) - Base = I->second; - RewriteMap[LHSUnknown->getValue()] = getUMinExpr(Base, RHS); - } - break; - } - case CmpInst::ICMP_EQ: - if (isa<SCEVConstant>(RHS)) - RewriteMap[LHSUnknown->getValue()] = RHS; - break; - case CmpInst::ICMP_NE: - if (isa<SCEVConstant>(RHS) && - cast<SCEVConstant>(RHS)->getValue()->isNullValue()) - RewriteMap[LHSUnknown->getValue()] = - getUMaxExpr(LHS, getOne(RHS->getType())); - break; - default: - break; - } - }; - // Starting at the loop predecessor, climb up the predecessor chain, as long - // as there are predecessors that can be found that have unique successors - // leading to the original header. - // TODO: share this logic with isLoopEntryGuardedByCond. - ValueToSCEVMapTy RewriteMap; - for (std::pair<const BasicBlock *, const BasicBlock *> Pair( - L->getLoopPredecessor(), L->getHeader()); - Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) { - - const BranchInst *LoopEntryPredicate = - dyn_cast<BranchInst>(Pair.first->getTerminator()); - if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional()) - continue; - - // TODO: use information from more complex conditions, e.g. AND expressions. - auto *Cmp = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition()); - if (!Cmp) - continue; - - auto Predicate = Cmp->getPredicate(); - if (LoopEntryPredicate->getSuccessor(1) == Pair.second) - Predicate = CmpInst::getInversePredicate(Predicate); - CollectCondition(Predicate, getSCEV(Cmp->getOperand(0)), - getSCEV(Cmp->getOperand(1)), RewriteMap); - } - - // Also collect information from assumptions dominating the loop. - for (auto &AssumeVH : AC.assumptions()) { - if (!AssumeVH) - continue; - auto *AssumeI = cast<CallInst>(AssumeVH); - auto *Cmp = dyn_cast<ICmpInst>(AssumeI->getOperand(0)); - if (!Cmp || !DT.dominates(AssumeI, L->getHeader())) - continue; - CollectCondition(Cmp->getPredicate(), getSCEV(Cmp->getOperand(0)), - getSCEV(Cmp->getOperand(1)), RewriteMap); - } - - if (RewriteMap.empty()) - return Expr; - SCEVLoopGuardRewriter Rewriter(*this, RewriteMap); - return Rewriter.visit(Expr); -} + +const SCEV * +ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) { + SmallVector<BasicBlock*, 16> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + + // Form an expression for the maximum exit count possible for this loop. We + // merge the max and exact information to approximate a version of + // getConstantMaxBackedgeTakenCount which isn't restricted to just constants. 
+ SmallVector<const SCEV*, 4> ExitCounts; + for (BasicBlock *ExitingBB : ExitingBlocks) { + const SCEV *ExitCount = getExitCount(L, ExitingBB); + if (isa<SCEVCouldNotCompute>(ExitCount)) + ExitCount = getExitCount(L, ExitingBB, + ScalarEvolution::ConstantMaximum); + if (!isa<SCEVCouldNotCompute>(ExitCount)) { + assert(DT.dominates(ExitingBB, L->getLoopLatch()) && + "We should only have known counts for exiting blocks that " + "dominate latch!"); + ExitCounts.push_back(ExitCount); + } + } + if (ExitCounts.empty()) + return getCouldNotCompute(); + return getUMinFromMismatchedTypes(ExitCounts); +} + +/// This rewriter is similar to SCEVParameterRewriter (it replaces SCEVUnknown +/// components following the Map (Value -> SCEV)), but skips AddRecExpr because +/// we cannot guarantee that the replacement is loop invariant in the loop of +/// the AddRec. +class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> { + ValueToSCEVMapTy ⤅ + +public: + SCEVLoopGuardRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M) + : SCEVRewriteVisitor(SE), Map(M) {} + + const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; } + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + auto I = Map.find(Expr->getValue()); + if (I == Map.end()) + return Expr; + return I->second; + } +}; + +const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) { + auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS, + const SCEV *RHS, ValueToSCEVMapTy &RewriteMap) { + if (!isa<SCEVUnknown>(LHS)) { + std::swap(LHS, RHS); + Predicate = CmpInst::getSwappedPredicate(Predicate); + } + + // For now, limit to conditions that provide information about unknown + // expressions. + auto *LHSUnknown = dyn_cast<SCEVUnknown>(LHS); + if (!LHSUnknown) + return; + + // TODO: use information from more predicates. + switch (Predicate) { + case CmpInst::ICMP_ULT: { + if (!containsAddRecurrence(RHS)) { + const SCEV *Base = LHS; + auto I = RewriteMap.find(LHSUnknown->getValue()); + if (I != RewriteMap.end()) + Base = I->second; + + RewriteMap[LHSUnknown->getValue()] = + getUMinExpr(Base, getMinusSCEV(RHS, getOne(RHS->getType()))); + } + break; + } + case CmpInst::ICMP_ULE: { + if (!containsAddRecurrence(RHS)) { + const SCEV *Base = LHS; + auto I = RewriteMap.find(LHSUnknown->getValue()); + if (I != RewriteMap.end()) + Base = I->second; + RewriteMap[LHSUnknown->getValue()] = getUMinExpr(Base, RHS); + } + break; + } + case CmpInst::ICMP_EQ: + if (isa<SCEVConstant>(RHS)) + RewriteMap[LHSUnknown->getValue()] = RHS; + break; + case CmpInst::ICMP_NE: + if (isa<SCEVConstant>(RHS) && + cast<SCEVConstant>(RHS)->getValue()->isNullValue()) + RewriteMap[LHSUnknown->getValue()] = + getUMaxExpr(LHS, getOne(RHS->getType())); + break; + default: + break; + } + }; + // Starting at the loop predecessor, climb up the predecessor chain, as long + // as there are predecessors that can be found that have unique successors + // leading to the original header. + // TODO: share this logic with isLoopEntryGuardedByCond. + ValueToSCEVMapTy RewriteMap; + for (std::pair<const BasicBlock *, const BasicBlock *> Pair( + L->getLoopPredecessor(), L->getHeader()); + Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) { + + const BranchInst *LoopEntryPredicate = + dyn_cast<BranchInst>(Pair.first->getTerminator()); + if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional()) + continue; + + // TODO: use information from more complex conditions, e.g. AND expressions. 
+ auto *Cmp = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition()); + if (!Cmp) + continue; + + auto Predicate = Cmp->getPredicate(); + if (LoopEntryPredicate->getSuccessor(1) == Pair.second) + Predicate = CmpInst::getInversePredicate(Predicate); + CollectCondition(Predicate, getSCEV(Cmp->getOperand(0)), + getSCEV(Cmp->getOperand(1)), RewriteMap); + } + + // Also collect information from assumptions dominating the loop. + for (auto &AssumeVH : AC.assumptions()) { + if (!AssumeVH) + continue; + auto *AssumeI = cast<CallInst>(AssumeVH); + auto *Cmp = dyn_cast<ICmpInst>(AssumeI->getOperand(0)); + if (!Cmp || !DT.dominates(AssumeI, L->getHeader())) + continue; + CollectCondition(Cmp->getPredicate(), getSCEV(Cmp->getOperand(0)), + getSCEV(Cmp->getOperand(1)), RewriteMap); + } + + if (RewriteMap.empty()) + return Expr; + SCEVLoopGuardRewriter Rewriter(*this, RewriteMap); + return Rewriter.visit(Expr); +}