| field | value | date |
|---|---|---|
| author | orivej <orivej@yandex-team.ru> | 2022-02-10 16:45:01 +0300 |
| committer | Daniil Cherednik <dcherednik@yandex-team.ru> | 2022-02-10 16:45:01 +0300 |
| commit | 2d37894b1b037cf24231090eda8589bbb44fb6fc (patch) | |
| tree | be835aa92c6248212e705f25388ebafcf84bc7a1 /contrib/libs/llvm12/lib/Target/ARM/MVEGatherScatterLowering.cpp | |
| parent | 718c552901d703c502ccbefdfc3c9028d608b947 (diff) | |
| download | ydb-2d37894b1b037cf24231090eda8589bbb44fb6fc.tar.gz | |
Restoring authorship annotation for <orivej@yandex-team.ru>. Commit 2 of 2.
Diffstat (limited to 'contrib/libs/llvm12/lib/Target/ARM/MVEGatherScatterLowering.cpp')

| mode | path | lines changed |
|---|---|---|
| -rw-r--r-- | contrib/libs/llvm12/lib/Target/ARM/MVEGatherScatterLowering.cpp | 1988 |

1 file changed, 994 insertions, 994 deletions
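For orientation before the patch body: the file being re-annotated implements the MVE gather/scatter lowering pass, whose first gate is a type/alignment legality check. The sketch below is a minimal, standalone restatement of that rule as it appears in `isLegalTypeAndAlignment` in the diff; the function name and signature here are illustrative and not part of the patch.

```cpp
// Minimal standalone sketch (illustrative, not from the patch) of the
// legality rule the pass applies before lowering: MVE gathers/scatters
// only exist for <4 x i32/i16/i8>, <8 x i16/i8> and <16 x i8> shapes,
// and the access must be at least naturally aligned to the element size.
static bool isLegalMveGatherShape(unsigned NumElements, unsigned ElemSizeBits,
                                  unsigned AlignmentBytes) {
  bool ShapeOk =
      (NumElements == 4 &&
       (ElemSizeBits == 32 || ElemSizeBits == 16 || ElemSizeBits == 8)) ||
      (NumElements == 8 && (ElemSizeBits == 16 || ElemSizeBits == 8)) ||
      (NumElements == 16 && ElemSizeBits == 8);
  return ShapeOk && AlignmentBytes >= ElemSizeBits / 8;
}
```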
diff --git a/contrib/libs/llvm12/lib/Target/ARM/MVEGatherScatterLowering.cpp b/contrib/libs/llvm12/lib/Target/ARM/MVEGatherScatterLowering.cpp index 542dd3bdca..56823735e2 100644 --- a/contrib/libs/llvm12/lib/Target/ARM/MVEGatherScatterLowering.cpp +++ b/contrib/libs/llvm12/lib/Target/ARM/MVEGatherScatterLowering.cpp @@ -1,177 +1,177 @@ -//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -/// This pass custom lowers llvm.gather and llvm.scatter instructions to -/// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to -/// produce a better final result as we go. -// -//===----------------------------------------------------------------------===// - -#include "ARM.h" -#include "ARMBaseInstrInfo.h" -#include "ARMSubtarget.h" -#include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/CodeGen/TargetLowering.h" -#include "llvm/CodeGen/TargetPassConfig.h" -#include "llvm/CodeGen/TargetSubtargetInfo.h" -#include "llvm/InitializePasses.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constant.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instruction.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/IR/IntrinsicsARM.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/PatternMatch.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" -#include "llvm/Pass.h" -#include "llvm/Support/Casting.h" -#include "llvm/Transforms/Utils/Local.h" -#include <algorithm> -#include <cassert> - -using namespace llvm; - +//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +/// This pass custom lowers llvm.gather and llvm.scatter instructions to +/// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to +/// produce a better final result as we go. 
+// +//===----------------------------------------------------------------------===// + +#include "ARM.h" +#include "ARMBaseInstrInfo.h" +#include "ARMSubtarget.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/InitializePasses.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsARM.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Transforms/Utils/Local.h" +#include <algorithm> +#include <cassert> + +using namespace llvm; + #define DEBUG_TYPE "arm-mve-gather-scatter-lowering" - -cl::opt<bool> EnableMaskedGatherScatters( + +cl::opt<bool> EnableMaskedGatherScatters( "enable-arm-maskedgatscat", cl::Hidden, cl::init(true), - cl::desc("Enable the generation of masked gathers and scatters")); - -namespace { - -class MVEGatherScatterLowering : public FunctionPass { -public: - static char ID; // Pass identification, replacement for typeid - - explicit MVEGatherScatterLowering() : FunctionPass(ID) { - initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - StringRef getPassName() const override { - return "MVE gather/scatter lowering"; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesCFG(); - AU.addRequired<TargetPassConfig>(); - AU.addRequired<LoopInfoWrapperPass>(); - FunctionPass::getAnalysisUsage(AU); - } - -private: - LoopInfo *LI = nullptr; - - // Check this is a valid gather with correct alignment - bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, - Align Alignment); - // Check whether Ptr is hidden behind a bitcast and look through it - void lookThroughBitcast(Value *&Ptr); - // Check for a getelementptr and deduce base and offsets from it, on success - // returning the base directly and the offsets indirectly using the Offsets - // argument + cl::desc("Enable the generation of masked gathers and scatters")); + +namespace { + +class MVEGatherScatterLowering : public FunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + + explicit MVEGatherScatterLowering() : FunctionPass(ID) { + initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; + + StringRef getPassName() const override { + return "MVE gather/scatter lowering"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<TargetPassConfig>(); + AU.addRequired<LoopInfoWrapperPass>(); + FunctionPass::getAnalysisUsage(AU); + } + +private: + LoopInfo *LI = nullptr; + + // Check this is a valid gather with correct alignment + bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, + Align Alignment); + // Check whether Ptr is hidden behind a bitcast and look through it + void lookThroughBitcast(Value *&Ptr); + // Check for a getelementptr and deduce base and offsets from it, on success + // returning the base 
directly and the offsets indirectly using the Offsets + // argument Value *checkGEP(Value *&Offsets, FixedVectorType *Ty, GetElementPtrInst *GEP, - IRBuilder<> &Builder); - // Compute the scale of this gather/scatter instruction - int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize); - // If the value is a constant, or derived from constants via additions - // and multilications, return its numeric value - Optional<int64_t> getIfConst(const Value *V); - // If Inst is an add instruction, check whether one summand is a - // constant. If so, scale this constant and return it together with - // the other summand. - std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale); - - Value *lowerGather(IntrinsicInst *I); - // Create a gather from a base + vector of offsets - Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr, - Instruction *&Root, IRBuilder<> &Builder); - // Create a gather from a vector of pointers - Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr, - IRBuilder<> &Builder, int64_t Increment = 0); - // Create an incrementing gather from a vector of pointers - Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr, - IRBuilder<> &Builder, - int64_t Increment = 0); - - Value *lowerScatter(IntrinsicInst *I); - // Create a scatter to a base + vector of offsets - Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets, - IRBuilder<> &Builder); - // Create a scatter to a vector of pointers - Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr, - IRBuilder<> &Builder, - int64_t Increment = 0); - // Create an incrementing scatter from a vector of pointers - Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr, - IRBuilder<> &Builder, - int64_t Increment = 0); - - // QI gathers and scatters can increment their offsets on their own if - // the increment is a constant value (digit) - Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr, - Value *Ptr, GetElementPtrInst *GEP, - IRBuilder<> &Builder); - // QI gathers/scatters can increment their offsets on their own if the - // increment is a constant value (digit) - this creates a writeback QI - // gather/scatter - Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr, - Value *Ptr, unsigned TypeScale, - IRBuilder<> &Builder); + IRBuilder<> &Builder); + // Compute the scale of this gather/scatter instruction + int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize); + // If the value is a constant, or derived from constants via additions + // and multilications, return its numeric value + Optional<int64_t> getIfConst(const Value *V); + // If Inst is an add instruction, check whether one summand is a + // constant. If so, scale this constant and return it together with + // the other summand. 
+ std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale); + + Value *lowerGather(IntrinsicInst *I); + // Create a gather from a base + vector of offsets + Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr, + Instruction *&Root, IRBuilder<> &Builder); + // Create a gather from a vector of pointers + Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr, + IRBuilder<> &Builder, int64_t Increment = 0); + // Create an incrementing gather from a vector of pointers + Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr, + IRBuilder<> &Builder, + int64_t Increment = 0); + + Value *lowerScatter(IntrinsicInst *I); + // Create a scatter to a base + vector of offsets + Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets, + IRBuilder<> &Builder); + // Create a scatter to a vector of pointers + Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr, + IRBuilder<> &Builder, + int64_t Increment = 0); + // Create an incrementing scatter from a vector of pointers + Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr, + IRBuilder<> &Builder, + int64_t Increment = 0); + + // QI gathers and scatters can increment their offsets on their own if + // the increment is a constant value (digit) + Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr, + Value *Ptr, GetElementPtrInst *GEP, + IRBuilder<> &Builder); + // QI gathers/scatters can increment their offsets on their own if the + // increment is a constant value (digit) - this creates a writeback QI + // gather/scatter + Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr, + Value *Ptr, unsigned TypeScale, + IRBuilder<> &Builder); // Optimise the base and offsets of the given address bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI); // Try to fold consecutive geps together into one Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder); - // Check whether these offsets could be moved out of the loop they're in - bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI); - // Pushes the given add out of the loop - void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex); - // Pushes the given mul out of the loop - void pushOutMul(PHINode *&Phi, Value *IncrementPerRound, - Value *OffsSecondOperand, unsigned LoopIncrement, - IRBuilder<> &Builder); -}; - -} // end anonymous namespace - -char MVEGatherScatterLowering::ID = 0; - -INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE, - "MVE gather/scattering lowering pass", false, false) - -Pass *llvm::createMVEGatherScatterLoweringPass() { - return new MVEGatherScatterLowering(); -} - -bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements, - unsigned ElemSize, - Align Alignment) { - if (((NumElements == 4 && - (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) || - (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) || - (NumElements == 16 && ElemSize == 8)) && - Alignment >= ElemSize / 8) - return true; - LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have " - << "valid alignment or vector type \n"); - return false; -} - + // Check whether these offsets could be moved out of the loop they're in + bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI); + // Pushes the given add out of the loop + void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex); + // Pushes the given mul out of the loop + void pushOutMul(PHINode *&Phi, Value 
*IncrementPerRound, + Value *OffsSecondOperand, unsigned LoopIncrement, + IRBuilder<> &Builder); +}; + +} // end anonymous namespace + +char MVEGatherScatterLowering::ID = 0; + +INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE, + "MVE gather/scattering lowering pass", false, false) + +Pass *llvm::createMVEGatherScatterLoweringPass() { + return new MVEGatherScatterLowering(); +} + +bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements, + unsigned ElemSize, + Align Alignment) { + if (((NumElements == 4 && + (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) || + (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) || + (NumElements == 16 && ElemSize == 8)) && + Alignment >= ElemSize / 8) + return true; + LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have " + << "valid alignment or vector type \n"); + return false; +} + static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) { // Offsets that are not of type <N x i32> are sign extended by the // getelementptr instruction, and MVE gathers/scatters treat the offset as @@ -215,37 +215,37 @@ static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) { } Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, FixedVectorType *Ty, - GetElementPtrInst *GEP, - IRBuilder<> &Builder) { - if (!GEP) { - LLVM_DEBUG( - dbgs() << "masked gathers/scatters: no getelementpointer found\n"); - return nullptr; - } - LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found." - << " Looking at intrinsic for base + vector of offsets\n"); - Value *GEPPtr = GEP->getPointerOperand(); + GetElementPtrInst *GEP, + IRBuilder<> &Builder) { + if (!GEP) { + LLVM_DEBUG( + dbgs() << "masked gathers/scatters: no getelementpointer found\n"); + return nullptr; + } + LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found." + << " Looking at intrinsic for base + vector of offsets\n"); + Value *GEPPtr = GEP->getPointerOperand(); Offsets = GEP->getOperand(1); if (GEPPtr->getType()->isVectorTy() || !isa<FixedVectorType>(Offsets->getType())) - return nullptr; - - if (GEP->getNumOperands() != 2) { - LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many" - << " operands. Expanding.\n"); - return nullptr; - } - Offsets = GEP->getOperand(1); + return nullptr; + + if (GEP->getNumOperands() != 2) { + LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many" + << " operands. Expanding.\n"); + return nullptr; + } + Offsets = GEP->getOperand(1); unsigned OffsetsElemCount = cast<FixedVectorType>(Offsets->getType())->getNumElements(); - // Paranoid check whether the number of parallel lanes is the same + // Paranoid check whether the number of parallel lanes is the same assert(Ty->getNumElements() == OffsetsElemCount); ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets); if (ZextOffs) - Offsets = ZextOffs->getOperand(0); + Offsets = ZextOffs->getOperand(0); FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType()); - + // If the offsets are already being zext-ed to <N x i32>, that relieves us of // having to make sure that they won't overflow. 
if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy()) @@ -256,779 +256,779 @@ Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, FixedVectorType *Ty, // The offset sizes have been checked; if any truncating or zext-ing is // required to fix them, do that now - if (Ty != Offsets->getType()) { + if (Ty != Offsets->getType()) { if ((Ty->getElementType()->getScalarSizeInBits() < OffsetType->getElementType()->getScalarSizeInBits())) { Offsets = Builder.CreateTrunc(Offsets, Ty); - } else { + } else { Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty)); - } - } - // If none of the checks failed, return the gep's base pointer - LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n"); - return GEPPtr; -} - -void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) { - // Look through bitcast instruction if #elements is the same - if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) { - auto *BCTy = cast<FixedVectorType>(BitCast->getType()); - auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType()); - if (BCTy->getNumElements() == BCSrcTy->getNumElements()) { - LLVM_DEBUG( - dbgs() << "masked gathers/scatters: looking through bitcast\n"); - Ptr = BitCast->getOperand(0); - } - } -} - -int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize, - unsigned MemoryElemSize) { - // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2, - // or a 8bit, 16bit or 32bit load/store scaled by 1 - if (GEPElemSize == 32 && MemoryElemSize == 32) - return 2; - else if (GEPElemSize == 16 && MemoryElemSize == 16) - return 1; - else if (GEPElemSize == 8) - return 0; - LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't " - << "create intrinsic\n"); - return -1; -} - -Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) { - const Constant *C = dyn_cast<Constant>(V); - if (C != nullptr) - return Optional<int64_t>{C->getUniqueInteger().getSExtValue()}; - if (!isa<Instruction>(V)) - return Optional<int64_t>{}; - - const Instruction *I = cast<Instruction>(V); - if (I->getOpcode() == Instruction::Add || - I->getOpcode() == Instruction::Mul) { - Optional<int64_t> Op0 = getIfConst(I->getOperand(0)); - Optional<int64_t> Op1 = getIfConst(I->getOperand(1)); - if (!Op0 || !Op1) - return Optional<int64_t>{}; - if (I->getOpcode() == Instruction::Add) - return Optional<int64_t>{Op0.getValue() + Op1.getValue()}; - if (I->getOpcode() == Instruction::Mul) - return Optional<int64_t>{Op0.getValue() * Op1.getValue()}; - } - return Optional<int64_t>{}; -} - -std::pair<Value *, int64_t> -MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) { - std::pair<Value *, int64_t> ReturnFalse = - std::pair<Value *, int64_t>(nullptr, 0); - // At this point, the instruction we're looking at must be an add or we - // bail out - Instruction *Add = dyn_cast<Instruction>(Inst); - if (Add == nullptr || Add->getOpcode() != Instruction::Add) - return ReturnFalse; - - Value *Summand; - Optional<int64_t> Const; - // Find out which operand the value that is increased is - if ((Const = getIfConst(Add->getOperand(0)))) - Summand = Add->getOperand(1); - else if ((Const = getIfConst(Add->getOperand(1)))) - Summand = Add->getOperand(0); - else - return ReturnFalse; - - // Check that the constant is small enough for an incrementing gather - int64_t Immediate = Const.getValue() << TypeScale; - if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0) - return ReturnFalse; - - return std::pair<Value *, int64_t>(Summand, 
Immediate); -} - -Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) { - using namespace PatternMatch; - LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"); - - // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) - // Attempt to turn the masked gather in I into a MVE intrinsic - // Potentially optimising the addressing modes as we do so. - auto *Ty = cast<FixedVectorType>(I->getType()); - Value *Ptr = I->getArgOperand(0); - Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue(); - Value *Mask = I->getArgOperand(2); - Value *PassThru = I->getArgOperand(3); - - if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), - Alignment)) - return nullptr; - lookThroughBitcast(Ptr); - assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); - - IRBuilder<> Builder(I->getContext()); - Builder.SetInsertPoint(I); - Builder.SetCurrentDebugLocation(I->getDebugLoc()); - - Instruction *Root = I; - Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder); - if (!Load) - Load = tryCreateMaskedGatherBase(I, Ptr, Builder); - if (!Load) - return nullptr; - - if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) { - LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - " - << "creating select\n"); - Load = Builder.CreateSelect(Mask, Load, PassThru); - } - - Root->replaceAllUsesWith(Load); - Root->eraseFromParent(); - if (Root != I) - // If this was an extending gather, we need to get rid of the sext/zext - // sext/zext as well as of the gather itself - I->eraseFromParent(); - - LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"); - return Load; -} - -Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I, - Value *Ptr, - IRBuilder<> &Builder, - int64_t Increment) { - using namespace PatternMatch; - auto *Ty = cast<FixedVectorType>(I->getType()); - LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n"); - if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) - // Can't build an intrinsic for this - return nullptr; - Value *Mask = I->getArgOperand(2); - if (match(Mask, m_One())) - return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base, - {Ty, Ptr->getType()}, - {Ptr, Builder.getInt32(Increment)}); - else - return Builder.CreateIntrinsic( - Intrinsic::arm_mve_vldr_gather_base_predicated, - {Ty, Ptr->getType(), Mask->getType()}, - {Ptr, Builder.getInt32(Increment), Mask}); -} - -Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB( - IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { - using namespace PatternMatch; - auto *Ty = cast<FixedVectorType>(I->getType()); - LLVM_DEBUG( - dbgs() - << "masked gathers: loading from vector of pointers with writeback\n"); - if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) - // Can't build an intrinsic for this - return nullptr; - Value *Mask = I->getArgOperand(2); - if (match(Mask, m_One())) - return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb, - {Ty, Ptr->getType()}, - {Ptr, Builder.getInt32(Increment)}); - else - return Builder.CreateIntrinsic( - Intrinsic::arm_mve_vldr_gather_base_wb_predicated, - {Ty, Ptr->getType(), Mask->getType()}, - {Ptr, Builder.getInt32(Increment), Mask}); -} - -Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset( - IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) { - using namespace PatternMatch; - - Type *OriginalTy = I->getType(); - Type *ResultTy = 
OriginalTy; - - unsigned Unsigned = 1; - // The size of the gather was already checked in isLegalTypeAndAlignment; - // if it was not a full vector width an appropriate extend should follow. - auto *Extend = Root; - if (OriginalTy->getPrimitiveSizeInBits() < 128) { - // Only transform gathers with exactly one use - if (!I->hasOneUse()) - return nullptr; - - // The correct root to replace is not the CallInst itself, but the - // instruction which extends it - Extend = cast<Instruction>(*I->users().begin()); - if (isa<SExtInst>(Extend)) { - Unsigned = 0; - } else if (!isa<ZExtInst>(Extend)) { - LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. " - << "Expanding\n"); - return nullptr; - } - LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n"); - ResultTy = Extend->getType(); - // The final size of the gather must be a full vector width - if (ResultTy->getPrimitiveSizeInBits() != 128) { - LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. " - << "Expanding\n"); - return nullptr; - } - } - - GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); - Value *Offsets; + } + } + // If none of the checks failed, return the gep's base pointer + LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n"); + return GEPPtr; +} + +void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) { + // Look through bitcast instruction if #elements is the same + if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) { + auto *BCTy = cast<FixedVectorType>(BitCast->getType()); + auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType()); + if (BCTy->getNumElements() == BCSrcTy->getNumElements()) { + LLVM_DEBUG( + dbgs() << "masked gathers/scatters: looking through bitcast\n"); + Ptr = BitCast->getOperand(0); + } + } +} + +int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize, + unsigned MemoryElemSize) { + // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2, + // or a 8bit, 16bit or 32bit load/store scaled by 1 + if (GEPElemSize == 32 && MemoryElemSize == 32) + return 2; + else if (GEPElemSize == 16 && MemoryElemSize == 16) + return 1; + else if (GEPElemSize == 8) + return 0; + LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. 
Can't " + << "create intrinsic\n"); + return -1; +} + +Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) { + const Constant *C = dyn_cast<Constant>(V); + if (C != nullptr) + return Optional<int64_t>{C->getUniqueInteger().getSExtValue()}; + if (!isa<Instruction>(V)) + return Optional<int64_t>{}; + + const Instruction *I = cast<Instruction>(V); + if (I->getOpcode() == Instruction::Add || + I->getOpcode() == Instruction::Mul) { + Optional<int64_t> Op0 = getIfConst(I->getOperand(0)); + Optional<int64_t> Op1 = getIfConst(I->getOperand(1)); + if (!Op0 || !Op1) + return Optional<int64_t>{}; + if (I->getOpcode() == Instruction::Add) + return Optional<int64_t>{Op0.getValue() + Op1.getValue()}; + if (I->getOpcode() == Instruction::Mul) + return Optional<int64_t>{Op0.getValue() * Op1.getValue()}; + } + return Optional<int64_t>{}; +} + +std::pair<Value *, int64_t> +MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) { + std::pair<Value *, int64_t> ReturnFalse = + std::pair<Value *, int64_t>(nullptr, 0); + // At this point, the instruction we're looking at must be an add or we + // bail out + Instruction *Add = dyn_cast<Instruction>(Inst); + if (Add == nullptr || Add->getOpcode() != Instruction::Add) + return ReturnFalse; + + Value *Summand; + Optional<int64_t> Const; + // Find out which operand the value that is increased is + if ((Const = getIfConst(Add->getOperand(0)))) + Summand = Add->getOperand(1); + else if ((Const = getIfConst(Add->getOperand(1)))) + Summand = Add->getOperand(0); + else + return ReturnFalse; + + // Check that the constant is small enough for an incrementing gather + int64_t Immediate = Const.getValue() << TypeScale; + if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0) + return ReturnFalse; + + return std::pair<Value *, int64_t>(Summand, Immediate); +} + +Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) { + using namespace PatternMatch; + LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"); + + // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) + // Attempt to turn the masked gather in I into a MVE intrinsic + // Potentially optimising the addressing modes as we do so. 
+ auto *Ty = cast<FixedVectorType>(I->getType()); + Value *Ptr = I->getArgOperand(0); + Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue(); + Value *Mask = I->getArgOperand(2); + Value *PassThru = I->getArgOperand(3); + + if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), + Alignment)) + return nullptr; + lookThroughBitcast(Ptr); + assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); + + IRBuilder<> Builder(I->getContext()); + Builder.SetInsertPoint(I); + Builder.SetCurrentDebugLocation(I->getDebugLoc()); + + Instruction *Root = I; + Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder); + if (!Load) + Load = tryCreateMaskedGatherBase(I, Ptr, Builder); + if (!Load) + return nullptr; + + if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) { + LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - " + << "creating select\n"); + Load = Builder.CreateSelect(Mask, Load, PassThru); + } + + Root->replaceAllUsesWith(Load); + Root->eraseFromParent(); + if (Root != I) + // If this was an extending gather, we need to get rid of the sext/zext + // sext/zext as well as of the gather itself + I->eraseFromParent(); + + LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"); + return Load; +} + +Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I, + Value *Ptr, + IRBuilder<> &Builder, + int64_t Increment) { + using namespace PatternMatch; + auto *Ty = cast<FixedVectorType>(I->getType()); + LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n"); + if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) + // Can't build an intrinsic for this + return nullptr; + Value *Mask = I->getArgOperand(2); + if (match(Mask, m_One())) + return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base, + {Ty, Ptr->getType()}, + {Ptr, Builder.getInt32(Increment)}); + else + return Builder.CreateIntrinsic( + Intrinsic::arm_mve_vldr_gather_base_predicated, + {Ty, Ptr->getType(), Mask->getType()}, + {Ptr, Builder.getInt32(Increment), Mask}); +} + +Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB( + IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { + using namespace PatternMatch; + auto *Ty = cast<FixedVectorType>(I->getType()); + LLVM_DEBUG( + dbgs() + << "masked gathers: loading from vector of pointers with writeback\n"); + if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) + // Can't build an intrinsic for this + return nullptr; + Value *Mask = I->getArgOperand(2); + if (match(Mask, m_One())) + return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb, + {Ty, Ptr->getType()}, + {Ptr, Builder.getInt32(Increment)}); + else + return Builder.CreateIntrinsic( + Intrinsic::arm_mve_vldr_gather_base_wb_predicated, + {Ty, Ptr->getType(), Mask->getType()}, + {Ptr, Builder.getInt32(Increment), Mask}); +} + +Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset( + IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) { + using namespace PatternMatch; + + Type *OriginalTy = I->getType(); + Type *ResultTy = OriginalTy; + + unsigned Unsigned = 1; + // The size of the gather was already checked in isLegalTypeAndAlignment; + // if it was not a full vector width an appropriate extend should follow. 
+ auto *Extend = Root; + if (OriginalTy->getPrimitiveSizeInBits() < 128) { + // Only transform gathers with exactly one use + if (!I->hasOneUse()) + return nullptr; + + // The correct root to replace is not the CallInst itself, but the + // instruction which extends it + Extend = cast<Instruction>(*I->users().begin()); + if (isa<SExtInst>(Extend)) { + Unsigned = 0; + } else if (!isa<ZExtInst>(Extend)) { + LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. " + << "Expanding\n"); + return nullptr; + } + LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n"); + ResultTy = Extend->getType(); + // The final size of the gather must be a full vector width + if (ResultTy->getPrimitiveSizeInBits() != 128) { + LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. " + << "Expanding\n"); + return nullptr; + } + } + + GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); + Value *Offsets; Value *BasePtr = checkGEP(Offsets, cast<FixedVectorType>(ResultTy), GEP, Builder); - if (!BasePtr) - return nullptr; - // Check whether the offset is a constant increment that could be merged into - // a QI gather - Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder); - if (Load) - return Load; - - int Scale = computeScale( - BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(), - OriginalTy->getScalarSizeInBits()); - if (Scale == -1) - return nullptr; - Root = Extend; - - Value *Mask = I->getArgOperand(2); - if (!match(Mask, m_One())) - return Builder.CreateIntrinsic( - Intrinsic::arm_mve_vldr_gather_offset_predicated, - {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()}, - {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), - Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask}); - else - return Builder.CreateIntrinsic( - Intrinsic::arm_mve_vldr_gather_offset, - {ResultTy, BasePtr->getType(), Offsets->getType()}, - {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), - Builder.getInt32(Scale), Builder.getInt32(Unsigned)}); -} - -Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) { - using namespace PatternMatch; - LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"); - - // @llvm.masked.scatter.*(data, ptrs, alignment, mask) - // Attempt to turn the masked scatter in I into a MVE intrinsic - // Potentially optimising the addressing modes as we do so. 
- Value *Input = I->getArgOperand(0); - Value *Ptr = I->getArgOperand(1); - Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue(); - auto *Ty = cast<FixedVectorType>(Input->getType()); - - if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), - Alignment)) - return nullptr; - - lookThroughBitcast(Ptr); - assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); - - IRBuilder<> Builder(I->getContext()); - Builder.SetInsertPoint(I); - Builder.SetCurrentDebugLocation(I->getDebugLoc()); - - Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder); - if (!Store) - Store = tryCreateMaskedScatterBase(I, Ptr, Builder); - if (!Store) - return nullptr; - - LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"); - I->eraseFromParent(); - return Store; -} - -Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase( - IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { - using namespace PatternMatch; - Value *Input = I->getArgOperand(0); - auto *Ty = cast<FixedVectorType>(Input->getType()); - // Only QR variants allow truncating - if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) { - // Can't build an intrinsic for this - return nullptr; - } - Value *Mask = I->getArgOperand(3); - // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask) - LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n"); - if (match(Mask, m_One())) - return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base, - {Ptr->getType(), Input->getType()}, - {Ptr, Builder.getInt32(Increment), Input}); - else - return Builder.CreateIntrinsic( - Intrinsic::arm_mve_vstr_scatter_base_predicated, - {Ptr->getType(), Input->getType(), Mask->getType()}, - {Ptr, Builder.getInt32(Increment), Input, Mask}); -} - -Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB( - IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { - using namespace PatternMatch; - Value *Input = I->getArgOperand(0); - auto *Ty = cast<FixedVectorType>(Input->getType()); - LLVM_DEBUG( - dbgs() - << "masked scatters: storing to a vector of pointers with writeback\n"); - if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) - // Can't build an intrinsic for this - return nullptr; - Value *Mask = I->getArgOperand(3); - if (match(Mask, m_One())) - return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb, - {Ptr->getType(), Input->getType()}, - {Ptr, Builder.getInt32(Increment), Input}); - else - return Builder.CreateIntrinsic( - Intrinsic::arm_mve_vstr_scatter_base_wb_predicated, - {Ptr->getType(), Input->getType(), Mask->getType()}, - {Ptr, Builder.getInt32(Increment), Input, Mask}); -} - -Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset( - IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { - using namespace PatternMatch; - Value *Input = I->getArgOperand(0); - Value *Mask = I->getArgOperand(3); - Type *InputTy = Input->getType(); - Type *MemoryTy = InputTy; - LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. 
Storing" - << " to base + vector of offsets\n"); - // If the input has been truncated, try to integrate that trunc into the - // scatter instruction (we don't care about alignment here) - if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) { - Value *PreTrunc = Trunc->getOperand(0); - Type *PreTruncTy = PreTrunc->getType(); - if (PreTruncTy->getPrimitiveSizeInBits() == 128) { - Input = PreTrunc; - InputTy = PreTruncTy; - } - } - if (InputTy->getPrimitiveSizeInBits() != 128) { - LLVM_DEBUG( - dbgs() << "masked scatters: cannot create scatters for non-standard" - << " input types. Expanding.\n"); - return nullptr; - } - - GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); - Value *Offsets; + if (!BasePtr) + return nullptr; + // Check whether the offset is a constant increment that could be merged into + // a QI gather + Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder); + if (Load) + return Load; + + int Scale = computeScale( + BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(), + OriginalTy->getScalarSizeInBits()); + if (Scale == -1) + return nullptr; + Root = Extend; + + Value *Mask = I->getArgOperand(2); + if (!match(Mask, m_One())) + return Builder.CreateIntrinsic( + Intrinsic::arm_mve_vldr_gather_offset_predicated, + {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()}, + {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), + Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask}); + else + return Builder.CreateIntrinsic( + Intrinsic::arm_mve_vldr_gather_offset, + {ResultTy, BasePtr->getType(), Offsets->getType()}, + {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), + Builder.getInt32(Scale), Builder.getInt32(Unsigned)}); +} + +Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) { + using namespace PatternMatch; + LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"); + + // @llvm.masked.scatter.*(data, ptrs, alignment, mask) + // Attempt to turn the masked scatter in I into a MVE intrinsic + // Potentially optimising the addressing modes as we do so. 
+ Value *Input = I->getArgOperand(0); + Value *Ptr = I->getArgOperand(1); + Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue(); + auto *Ty = cast<FixedVectorType>(Input->getType()); + + if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), + Alignment)) + return nullptr; + + lookThroughBitcast(Ptr); + assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); + + IRBuilder<> Builder(I->getContext()); + Builder.SetInsertPoint(I); + Builder.SetCurrentDebugLocation(I->getDebugLoc()); + + Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder); + if (!Store) + Store = tryCreateMaskedScatterBase(I, Ptr, Builder); + if (!Store) + return nullptr; + + LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"); + I->eraseFromParent(); + return Store; +} + +Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase( + IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { + using namespace PatternMatch; + Value *Input = I->getArgOperand(0); + auto *Ty = cast<FixedVectorType>(Input->getType()); + // Only QR variants allow truncating + if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) { + // Can't build an intrinsic for this + return nullptr; + } + Value *Mask = I->getArgOperand(3); + // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask) + LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n"); + if (match(Mask, m_One())) + return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base, + {Ptr->getType(), Input->getType()}, + {Ptr, Builder.getInt32(Increment), Input}); + else + return Builder.CreateIntrinsic( + Intrinsic::arm_mve_vstr_scatter_base_predicated, + {Ptr->getType(), Input->getType(), Mask->getType()}, + {Ptr, Builder.getInt32(Increment), Input, Mask}); +} + +Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB( + IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { + using namespace PatternMatch; + Value *Input = I->getArgOperand(0); + auto *Ty = cast<FixedVectorType>(Input->getType()); + LLVM_DEBUG( + dbgs() + << "masked scatters: storing to a vector of pointers with writeback\n"); + if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) + // Can't build an intrinsic for this + return nullptr; + Value *Mask = I->getArgOperand(3); + if (match(Mask, m_One())) + return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb, + {Ptr->getType(), Input->getType()}, + {Ptr, Builder.getInt32(Increment), Input}); + else + return Builder.CreateIntrinsic( + Intrinsic::arm_mve_vstr_scatter_base_wb_predicated, + {Ptr->getType(), Input->getType(), Mask->getType()}, + {Ptr, Builder.getInt32(Increment), Input, Mask}); +} + +Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset( + IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { + using namespace PatternMatch; + Value *Input = I->getArgOperand(0); + Value *Mask = I->getArgOperand(3); + Type *InputTy = Input->getType(); + Type *MemoryTy = InputTy; + LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. 
Storing" + << " to base + vector of offsets\n"); + // If the input has been truncated, try to integrate that trunc into the + // scatter instruction (we don't care about alignment here) + if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) { + Value *PreTrunc = Trunc->getOperand(0); + Type *PreTruncTy = PreTrunc->getType(); + if (PreTruncTy->getPrimitiveSizeInBits() == 128) { + Input = PreTrunc; + InputTy = PreTruncTy; + } + } + if (InputTy->getPrimitiveSizeInBits() != 128) { + LLVM_DEBUG( + dbgs() << "masked scatters: cannot create scatters for non-standard" + << " input types. Expanding.\n"); + return nullptr; + } + + GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); + Value *Offsets; Value *BasePtr = checkGEP(Offsets, cast<FixedVectorType>(InputTy), GEP, Builder); - if (!BasePtr) - return nullptr; - // Check whether the offset is a constant increment that could be merged into - // a QI gather - Value *Store = - tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder); - if (Store) - return Store; - int Scale = computeScale( - BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(), - MemoryTy->getScalarSizeInBits()); - if (Scale == -1) - return nullptr; - - if (!match(Mask, m_One())) - return Builder.CreateIntrinsic( - Intrinsic::arm_mve_vstr_scatter_offset_predicated, - {BasePtr->getType(), Offsets->getType(), Input->getType(), - Mask->getType()}, - {BasePtr, Offsets, Input, - Builder.getInt32(MemoryTy->getScalarSizeInBits()), - Builder.getInt32(Scale), Mask}); - else - return Builder.CreateIntrinsic( - Intrinsic::arm_mve_vstr_scatter_offset, - {BasePtr->getType(), Offsets->getType(), Input->getType()}, - {BasePtr, Offsets, Input, - Builder.getInt32(MemoryTy->getScalarSizeInBits()), - Builder.getInt32(Scale)}); -} - -Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat( - IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP, - IRBuilder<> &Builder) { - FixedVectorType *Ty; - if (I->getIntrinsicID() == Intrinsic::masked_gather) - Ty = cast<FixedVectorType>(I->getType()); - else - Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType()); - // Incrementing gathers only exist for v4i32 - if (Ty->getNumElements() != 4 || - Ty->getScalarSizeInBits() != 32) - return nullptr; - Loop *L = LI->getLoopFor(I->getParent()); - if (L == nullptr) - // Incrementing gathers are not beneficial outside of a loop - return nullptr; - LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " - "wb gather/scatter\n"); - - // The gep was in charge of making sure the offsets are scaled correctly - // - calculate that factor so it can be applied by hand - DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout(); - int TypeScale = - computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()), - DT.getTypeSizeInBits(GEP->getType()) / - cast<FixedVectorType>(GEP->getType())->getNumElements()); - if (TypeScale == -1) - return nullptr; - - if (GEP->hasOneUse()) { - // Only in this case do we want to build a wb gather, because the wb will - // change the phi which does affect other users of the gep (which will still - // be using the phi in the old way) - Value *Load = - tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder); - if (Load != nullptr) - return Load; - } - LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " - "non-wb gather/scatter\n"); - - std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); - if (Add.first == nullptr) - return nullptr; - Value 
*OffsetsIncoming = Add.first; - int64_t Immediate = Add.second; - - // Make sure the offsets are scaled correctly - Instruction *ScaledOffsets = BinaryOperator::Create( - Instruction::Shl, OffsetsIncoming, - Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)), - "ScaledIndex", I); - // Add the base to the offsets - OffsetsIncoming = BinaryOperator::Create( - Instruction::Add, ScaledOffsets, - Builder.CreateVectorSplat( - Ty->getNumElements(), - Builder.CreatePtrToInt( - BasePtr, - cast<VectorType>(ScaledOffsets->getType())->getElementType())), - "StartIndex", I); - - if (I->getIntrinsicID() == Intrinsic::masked_gather) - return cast<IntrinsicInst>( - tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate)); - else - return cast<IntrinsicInst>( - tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate)); -} - -Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat( - IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale, - IRBuilder<> &Builder) { - // Check whether this gather's offset is incremented by a constant - if so, - // and the load is of the right type, we can merge this into a QI gather - Loop *L = LI->getLoopFor(I->getParent()); - // Offsets that are worth merging into this instruction will be incremented - // by a constant, thus we're looking for an add of a phi and a constant - PHINode *Phi = dyn_cast<PHINode>(Offsets); - if (Phi == nullptr || Phi->getNumIncomingValues() != 2 || - Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2) - // No phi means no IV to write back to; if there is a phi, we expect it - // to have exactly two incoming values; the only phis we are interested in - // will be loop IV's and have exactly two uses, one in their increment and - // one in the gather's gep - return nullptr; - - unsigned IncrementIndex = - Phi->getIncomingBlock(0) == L->getLoopLatch() ? 
0 : 1; - // Look through the phi to the phi increment - Offsets = Phi->getIncomingValue(IncrementIndex); - - std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); - if (Add.first == nullptr) - return nullptr; - Value *OffsetsIncoming = Add.first; - int64_t Immediate = Add.second; - if (OffsetsIncoming != Phi) - // Then the increment we are looking at is not an increment of the - // induction variable, and we don't want to do a writeback - return nullptr; - - Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back()); - unsigned NumElems = - cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements(); - - // Make sure the offsets are scaled correctly - Instruction *ScaledOffsets = BinaryOperator::Create( - Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex), - Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)), - "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); - // Add the base to the offsets - OffsetsIncoming = BinaryOperator::Create( - Instruction::Add, ScaledOffsets, - Builder.CreateVectorSplat( - NumElems, - Builder.CreatePtrToInt( - BasePtr, - cast<VectorType>(ScaledOffsets->getType())->getElementType())), - "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); - // The gather is pre-incrementing - OffsetsIncoming = BinaryOperator::Create( - Instruction::Sub, OffsetsIncoming, - Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)), - "PreIncrementStartIndex", - &Phi->getIncomingBlock(1 - IncrementIndex)->back()); - Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming); - - Builder.SetInsertPoint(I); - - Value *EndResult; - Value *NewInduction; - if (I->getIntrinsicID() == Intrinsic::masked_gather) { - // Build the incrementing gather - Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate); - // One value to be handed to whoever uses the gather, one is the loop - // increment - EndResult = Builder.CreateExtractValue(Load, 0, "Gather"); - NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement"); - } else { - // Build the incrementing scatter - NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate); - EndResult = NewInduction; - } - Instruction *AddInst = cast<Instruction>(Offsets); - AddInst->replaceAllUsesWith(NewInduction); - AddInst->eraseFromParent(); - Phi->setIncomingValue(IncrementIndex, NewInduction); - - return EndResult; -} - -void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi, - Value *OffsSecondOperand, - unsigned StartIndex) { - LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n"); - Instruction *InsertionPoint = - &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back()); - // Initialize the phi with a vector that contains a sum of the constants - Instruction *NewIndex = BinaryOperator::Create( - Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand, - "PushedOutAdd", InsertionPoint); - unsigned IncrementIndex = StartIndex == 0 ? 
1 : 0; - - // Order such that start index comes first (this reduces mov's) - Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex)); - Phi->addIncoming(Phi->getIncomingValue(IncrementIndex), - Phi->getIncomingBlock(IncrementIndex)); - Phi->removeIncomingValue(IncrementIndex); - Phi->removeIncomingValue(StartIndex); -} - -void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi, - Value *IncrementPerRound, - Value *OffsSecondOperand, - unsigned LoopIncrement, - IRBuilder<> &Builder) { - LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n"); - - // Create a new scalar add outside of the loop and transform it to a splat - // by which loop variable can be incremented - Instruction *InsertionPoint = &cast<Instruction>( - Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back()); - - // Create a new index - Value *StartIndex = BinaryOperator::Create( - Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1), - OffsSecondOperand, "PushedOutMul", InsertionPoint); - - Instruction *Product = - BinaryOperator::Create(Instruction::Mul, IncrementPerRound, - OffsSecondOperand, "Product", InsertionPoint); - // Increment NewIndex by Product instead of the multiplication - Instruction *NewIncrement = BinaryOperator::Create( - Instruction::Add, Phi, Product, "IncrementPushedOutMul", - cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back()) - .getPrevNode()); - - Phi->addIncoming(StartIndex, - Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)); - Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement)); - Phi->removeIncomingValue((unsigned)0); - Phi->removeIncomingValue((unsigned)0); -} - -// Check whether all usages of this instruction are as offsets of -// gathers/scatters or simple arithmetics only used by gathers/scatters -static bool hasAllGatScatUsers(Instruction *I) { - if (I->hasNUses(0)) { - return false; - } - bool Gatscat = true; - for (User *U : I->users()) { - if (!isa<Instruction>(U)) - return false; - if (isa<GetElementPtrInst>(U) || - isGatherScatter(dyn_cast<IntrinsicInst>(U))) { - return Gatscat; - } else { - unsigned OpCode = cast<Instruction>(U)->getOpcode(); - if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) && - hasAllGatScatUsers(cast<Instruction>(U))) { - continue; - } - return false; - } - } - return Gatscat; -} - -bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB, - LoopInfo *LI) { - LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"); - // Optimise the addresses of gathers/scatters by moving invariant - // calculations out of the loop - if (!isa<Instruction>(Offsets)) - return false; - Instruction *Offs = cast<Instruction>(Offsets); - if (Offs->getOpcode() != Instruction::Add && - Offs->getOpcode() != Instruction::Mul) - return false; - Loop *L = LI->getLoopFor(BB); - if (L == nullptr) - return false; - if (!Offs->hasOneUse()) { - if (!hasAllGatScatUsers(Offs)) - return false; - } - - // Find out which, if any, operand of the instruction - // is a phi node - PHINode *Phi; - int OffsSecondOp; - if (isa<PHINode>(Offs->getOperand(0))) { - Phi = cast<PHINode>(Offs->getOperand(0)); - OffsSecondOp = 1; - } else if (isa<PHINode>(Offs->getOperand(1))) { - Phi = cast<PHINode>(Offs->getOperand(1)); - OffsSecondOp = 0; - } else { - bool Changed = true; - if (isa<Instruction>(Offs->getOperand(0)) && - L->contains(cast<Instruction>(Offs->getOperand(0)))) - Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI); - if (isa<Instruction>(Offs->getOperand(1)) && - 
L->contains(cast<Instruction>(Offs->getOperand(1)))) - Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI); - if (!Changed) { - return false; - } else { - if (isa<PHINode>(Offs->getOperand(0))) { - Phi = cast<PHINode>(Offs->getOperand(0)); - OffsSecondOp = 1; - } else if (isa<PHINode>(Offs->getOperand(1))) { - Phi = cast<PHINode>(Offs->getOperand(1)); - OffsSecondOp = 0; - } else { - return false; - } - } - } - // A phi node we want to perform this function on should be from the - // loop header, and shouldn't have more than 2 incoming values - if (Phi->getParent() != L->getHeader() || - Phi->getNumIncomingValues() != 2) - return false; - - // The phi must be an induction variable - int IncrementingBlock = -1; - - for (int i = 0; i < 2; i++) + if (!BasePtr) + return nullptr; + // Check whether the offset is a constant increment that could be merged into + // a QI gather + Value *Store = + tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder); + if (Store) + return Store; + int Scale = computeScale( + BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(), + MemoryTy->getScalarSizeInBits()); + if (Scale == -1) + return nullptr; + + if (!match(Mask, m_One())) + return Builder.CreateIntrinsic( + Intrinsic::arm_mve_vstr_scatter_offset_predicated, + {BasePtr->getType(), Offsets->getType(), Input->getType(), + Mask->getType()}, + {BasePtr, Offsets, Input, + Builder.getInt32(MemoryTy->getScalarSizeInBits()), + Builder.getInt32(Scale), Mask}); + else + return Builder.CreateIntrinsic( + Intrinsic::arm_mve_vstr_scatter_offset, + {BasePtr->getType(), Offsets->getType(), Input->getType()}, + {BasePtr, Offsets, Input, + Builder.getInt32(MemoryTy->getScalarSizeInBits()), + Builder.getInt32(Scale)}); +} + +Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat( + IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP, + IRBuilder<> &Builder) { + FixedVectorType *Ty; + if (I->getIntrinsicID() == Intrinsic::masked_gather) + Ty = cast<FixedVectorType>(I->getType()); + else + Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType()); + // Incrementing gathers only exist for v4i32 + if (Ty->getNumElements() != 4 || + Ty->getScalarSizeInBits() != 32) + return nullptr; + Loop *L = LI->getLoopFor(I->getParent()); + if (L == nullptr) + // Incrementing gathers are not beneficial outside of a loop + return nullptr; + LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " + "wb gather/scatter\n"); + + // The gep was in charge of making sure the offsets are scaled correctly + // - calculate that factor so it can be applied by hand + DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout(); + int TypeScale = + computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()), + DT.getTypeSizeInBits(GEP->getType()) / + cast<FixedVectorType>(GEP->getType())->getNumElements()); + if (TypeScale == -1) + return nullptr; + + if (GEP->hasOneUse()) { + // Only in this case do we want to build a wb gather, because the wb will + // change the phi which does affect other users of the gep (which will still + // be using the phi in the old way) + Value *Load = + tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder); + if (Load != nullptr) + return Load; + } + LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " + "non-wb gather/scatter\n"); + + std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); + if (Add.first == nullptr) + return nullptr; + Value *OffsetsIncoming = Add.first; 
+ int64_t Immediate = Add.second; + + // Make sure the offsets are scaled correctly + Instruction *ScaledOffsets = BinaryOperator::Create( + Instruction::Shl, OffsetsIncoming, + Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)), + "ScaledIndex", I); + // Add the base to the offsets + OffsetsIncoming = BinaryOperator::Create( + Instruction::Add, ScaledOffsets, + Builder.CreateVectorSplat( + Ty->getNumElements(), + Builder.CreatePtrToInt( + BasePtr, + cast<VectorType>(ScaledOffsets->getType())->getElementType())), + "StartIndex", I); + + if (I->getIntrinsicID() == Intrinsic::masked_gather) + return cast<IntrinsicInst>( + tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate)); + else + return cast<IntrinsicInst>( + tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate)); +} + +Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat( + IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale, + IRBuilder<> &Builder) { + // Check whether this gather's offset is incremented by a constant - if so, + // and the load is of the right type, we can merge this into a QI gather + Loop *L = LI->getLoopFor(I->getParent()); + // Offsets that are worth merging into this instruction will be incremented + // by a constant, thus we're looking for an add of a phi and a constant + PHINode *Phi = dyn_cast<PHINode>(Offsets); + if (Phi == nullptr || Phi->getNumIncomingValues() != 2 || + Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2) + // No phi means no IV to write back to; if there is a phi, we expect it + // to have exactly two incoming values; the only phis we are interested in + // will be loop IV's and have exactly two uses, one in their increment and + // one in the gather's gep + return nullptr; + + unsigned IncrementIndex = + Phi->getIncomingBlock(0) == L->getLoopLatch() ?
0 : 1; + // Look through the phi to the phi increment + Offsets = Phi->getIncomingValue(IncrementIndex); + + std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); + if (Add.first == nullptr) + return nullptr; + Value *OffsetsIncoming = Add.first; + int64_t Immediate = Add.second; + if (OffsetsIncoming != Phi) + // Then the increment we are looking at is not an increment of the + // induction variable, and we don't want to do a writeback + return nullptr; + + Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back()); + unsigned NumElems = + cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements(); + + // Make sure the offsets are scaled correctly + Instruction *ScaledOffsets = BinaryOperator::Create( + Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex), + Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)), + "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); + // Add the base to the offsets + OffsetsIncoming = BinaryOperator::Create( + Instruction::Add, ScaledOffsets, + Builder.CreateVectorSplat( + NumElems, + Builder.CreatePtrToInt( + BasePtr, + cast<VectorType>(ScaledOffsets->getType())->getElementType())), + "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); + // The gather is pre-incrementing + OffsetsIncoming = BinaryOperator::Create( + Instruction::Sub, OffsetsIncoming, + Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)), + "PreIncrementStartIndex", + &Phi->getIncomingBlock(1 - IncrementIndex)->back()); + Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming); + + Builder.SetInsertPoint(I); + + Value *EndResult; + Value *NewInduction; + if (I->getIntrinsicID() == Intrinsic::masked_gather) { + // Build the incrementing gather + Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate); + // One value to be handed to whoever uses the gather, one is the loop + // increment + EndResult = Builder.CreateExtractValue(Load, 0, "Gather"); + NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement"); + } else { + // Build the incrementing scatter + NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate); + EndResult = NewInduction; + } + Instruction *AddInst = cast<Instruction>(Offsets); + AddInst->replaceAllUsesWith(NewInduction); + AddInst->eraseFromParent(); + Phi->setIncomingValue(IncrementIndex, NewInduction); + + return EndResult; +} + +void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi, + Value *OffsSecondOperand, + unsigned StartIndex) { + LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n"); + Instruction *InsertionPoint = + &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back()); + // Initialize the phi with a vector that contains a sum of the constants + Instruction *NewIndex = BinaryOperator::Create( + Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand, + "PushedOutAdd", InsertionPoint); + unsigned IncrementIndex = StartIndex == 0 ?
1 : 0; + + // Order such that start index comes first (this reduces mov's) + Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex)); + Phi->addIncoming(Phi->getIncomingValue(IncrementIndex), + Phi->getIncomingBlock(IncrementIndex)); + Phi->removeIncomingValue(IncrementIndex); + Phi->removeIncomingValue(StartIndex); +} + +void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi, + Value *IncrementPerRound, + Value *OffsSecondOperand, + unsigned LoopIncrement, + IRBuilder<> &Builder) { + LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n"); + + // Create a new scalar add outside of the loop and transform it to a splat + // by which loop variable can be incremented + Instruction *InsertionPoint = &cast<Instruction>( + Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back()); + + // Create a new index + Value *StartIndex = BinaryOperator::Create( + Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1), + OffsSecondOperand, "PushedOutMul", InsertionPoint); + + Instruction *Product = + BinaryOperator::Create(Instruction::Mul, IncrementPerRound, + OffsSecondOperand, "Product", InsertionPoint); + // Increment NewIndex by Product instead of the multiplication + Instruction *NewIncrement = BinaryOperator::Create( + Instruction::Add, Phi, Product, "IncrementPushedOutMul", + cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back()) + .getPrevNode()); + + Phi->addIncoming(StartIndex, + Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)); + Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement)); + Phi->removeIncomingValue((unsigned)0); + Phi->removeIncomingValue((unsigned)0); +} + +// Check whether all usages of this instruction are as offsets of +// gathers/scatters or simple arithmetics only used by gathers/scatters +static bool hasAllGatScatUsers(Instruction *I) { + if (I->hasNUses(0)) { + return false; + } + bool Gatscat = true; + for (User *U : I->users()) { + if (!isa<Instruction>(U)) + return false; + if (isa<GetElementPtrInst>(U) || + isGatherScatter(dyn_cast<IntrinsicInst>(U))) { + return Gatscat; + } else { + unsigned OpCode = cast<Instruction>(U)->getOpcode(); + if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) && + hasAllGatScatUsers(cast<Instruction>(U))) { + continue; + } + return false; + } + } + return Gatscat; +} + +bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB, + LoopInfo *LI) { + LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"); + // Optimise the addresses of gathers/scatters by moving invariant + // calculations out of the loop + if (!isa<Instruction>(Offsets)) + return false; + Instruction *Offs = cast<Instruction>(Offsets); + if (Offs->getOpcode() != Instruction::Add && + Offs->getOpcode() != Instruction::Mul) + return false; + Loop *L = LI->getLoopFor(BB); + if (L == nullptr) + return false; + if (!Offs->hasOneUse()) { + if (!hasAllGatScatUsers(Offs)) + return false; + } + + // Find out which, if any, operand of the instruction + // is a phi node + PHINode *Phi; + int OffsSecondOp; + if (isa<PHINode>(Offs->getOperand(0))) { + Phi = cast<PHINode>(Offs->getOperand(0)); + OffsSecondOp = 1; + } else if (isa<PHINode>(Offs->getOperand(1))) { + Phi = cast<PHINode>(Offs->getOperand(1)); + OffsSecondOp = 0; + } else { + bool Changed = true; + if (isa<Instruction>(Offs->getOperand(0)) && + L->contains(cast<Instruction>(Offs->getOperand(0)))) + Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI); + if (isa<Instruction>(Offs->getOperand(1)) && +
L->contains(cast<Instruction>(Offs->getOperand(1)))) + Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI); + if (!Changed) { + return false; + } else { + if (isa<PHINode>(Offs->getOperand(0))) { + Phi = cast<PHINode>(Offs->getOperand(0)); + OffsSecondOp = 1; + } else if (isa<PHINode>(Offs->getOperand(1))) { + Phi = cast<PHINode>(Offs->getOperand(1)); + OffsSecondOp = 0; + } else { + return false; + } + } + } + // A phi node we want to perform this function on should be from the + // loop header, and shouldn't have more than 2 incoming values + if (Phi->getParent() != L->getHeader() || + Phi->getNumIncomingValues() != 2) + return false; + + // The phi must be an induction variable + int IncrementingBlock = -1; + + for (int i = 0; i < 2; i++) if (auto *Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) - if (Op->getOpcode() == Instruction::Add && - (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi)) - IncrementingBlock = i; - if (IncrementingBlock == -1) - return false; - - Instruction *IncInstruction = - cast<Instruction>(Phi->getIncomingValue(IncrementingBlock)); - - // If the phi is not used by anything else, we can just adapt it when - // replacing the instruction; if it is, we'll have to duplicate it - PHINode *NewPhi; - Value *IncrementPerRound = IncInstruction->getOperand( - (IncInstruction->getOperand(0) == Phi) ? 1 : 0); - - // Get the value that is added to/multiplied with the phi - Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp); - + if (Op->getOpcode() == Instruction::Add && + (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi)) + IncrementingBlock = i; + if (IncrementingBlock == -1) + return false; + + Instruction *IncInstruction = + cast<Instruction>(Phi->getIncomingValue(IncrementingBlock)); + + // If the phi is not used by anything else, we can just adapt it when + // replacing the instruction; if it is, we'll have to duplicate it + PHINode *NewPhi; + Value *IncrementPerRound = IncInstruction->getOperand( + (IncInstruction->getOperand(0) == Phi) ? 1 : 0); + + // Get the value that is added to/multiplied with the phi + Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp); + if (IncrementPerRound->getType() != OffsSecondOperand->getType() || !L->isLoopInvariant(OffsSecondOperand)) - // Something has gone wrong, abort - return false; - - // Only proceed if the increment per round is a constant or an instruction - // which does not originate from within the loop - if (!isa<Constant>(IncrementPerRound) && - !(isa<Instruction>(IncrementPerRound) && - !L->contains(cast<Instruction>(IncrementPerRound)))) - return false; - - if (Phi->getNumUses() == 2) { - // No other users -> reuse existing phi (One user is the instruction - // we're looking at, the other is the phi increment) - if (IncInstruction->getNumUses() != 1) { - // If the incrementing instruction does have more users than - // our phi, we need to copy it - IncInstruction = BinaryOperator::Create( - Instruction::BinaryOps(IncInstruction->getOpcode()), Phi, - IncrementPerRound, "LoopIncrement", IncInstruction); - Phi->setIncomingValue(IncrementingBlock, IncInstruction); - } - NewPhi = Phi; - } else { - // There are other users -> create a new phi - NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi); - std::vector<Value *> Increases; - // Copy the incoming values of the old phi - NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1), - Phi->getIncomingBlock(IncrementingBlock == 1 ?
0 : 1)); - IncInstruction = BinaryOperator::Create( - Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi, - IncrementPerRound, "LoopIncrement", IncInstruction); - NewPhi->addIncoming(IncInstruction, - Phi->getIncomingBlock(IncrementingBlock)); - IncrementingBlock = 1; - } - - IRBuilder<> Builder(BB->getContext()); - Builder.SetInsertPoint(Phi); - Builder.SetCurrentDebugLocation(Offs->getDebugLoc()); - - switch (Offs->getOpcode()) { - case Instruction::Add: - pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1); - break; - case Instruction::Mul: - pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock, - Builder); - break; - default: - return false; - } - LLVM_DEBUG( - dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n"); - - // The instruction has now been "absorbed" into the phi value - Offs->replaceAllUsesWith(NewPhi); - if (Offs->hasNUses(0)) - Offs->eraseFromParent(); - // Clean up the old increment in case it's unused because we built a new - // one - if (IncInstruction->hasNUses(0)) - IncInstruction->eraseFromParent(); - - return true; -} - + // Something has gone wrong, abort + return false; + + // Only proceed if the increment per round is a constant or an instruction + // which does not originate from within the loop + if (!isa<Constant>(IncrementPerRound) && + !(isa<Instruction>(IncrementPerRound) && + !L->contains(cast<Instruction>(IncrementPerRound)))) + return false; + + if (Phi->getNumUses() == 2) { + // No other users -> reuse existing phi (One user is the instruction + // we're looking at, the other is the phi increment) + if (IncInstruction->getNumUses() != 1) { + // If the incrementing instruction does have more users than + // our phi, we need to copy it + IncInstruction = BinaryOperator::Create( + Instruction::BinaryOps(IncInstruction->getOpcode()), Phi, + IncrementPerRound, "LoopIncrement", IncInstruction); + Phi->setIncomingValue(IncrementingBlock, IncInstruction); + } + NewPhi = Phi; + } else { + // There are other users -> create a new phi + NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi); + std::vector<Value *> Increases; + // Copy the incoming values of the old phi + NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1), + Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1)); + IncInstruction = BinaryOperator::Create( + Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi, + IncrementPerRound, "LoopIncrement", IncInstruction); + NewPhi->addIncoming(IncInstruction, + Phi->getIncomingBlock(IncrementingBlock)); + IncrementingBlock = 1; + } + + IRBuilder<> Builder(BB->getContext()); + Builder.SetInsertPoint(Phi); + Builder.SetCurrentDebugLocation(Offs->getDebugLoc()); + + switch (Offs->getOpcode()) { + case Instruction::Add: + pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ?
0 : 1); + break; + case Instruction::Mul: + pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock, + Builder); + break; + default: + return false; + } + LLVM_DEBUG( + dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n"); + + // The instruction has now been "absorbed" into the phi value + Offs->replaceAllUsesWith(NewPhi); + if (Offs->hasNUses(0)) + Offs->eraseFromParent(); + // Clean up the old increment in case it's unused because we built a new + // one + if (IncInstruction->hasNUses(0)) + IncInstruction->eraseFromParent(); + + return true; +} + static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP, IRBuilder<> &Builder) { // Splat the non-vector value to a vector of the given type - if the value is @@ -1151,56 +1151,56 @@ bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB, return Changed; } -bool MVEGatherScatterLowering::runOnFunction(Function &F) { - if (!EnableMaskedGatherScatters) - return false; - auto &TPC = getAnalysis<TargetPassConfig>(); - auto &TM = TPC.getTM<TargetMachine>(); - auto *ST = &TM.getSubtarget<ARMSubtarget>(F); - if (!ST->hasMVEIntegerOps()) - return false; - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - SmallVector<IntrinsicInst *, 4> Gathers; - SmallVector<IntrinsicInst *, 4> Scatters; - - bool Changed = false; - - for (BasicBlock &BB : F) { +bool MVEGatherScatterLowering::runOnFunction(Function &F) { + if (!EnableMaskedGatherScatters) + return false; + auto &TPC = getAnalysis<TargetPassConfig>(); + auto &TM = TPC.getTM<TargetMachine>(); + auto *ST = &TM.getSubtarget<ARMSubtarget>(F); + if (!ST->hasMVEIntegerOps()) + return false; + LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + SmallVector<IntrinsicInst *, 4> Gathers; + SmallVector<IntrinsicInst *, 4> Scatters; + + bool Changed = false; + + for (BasicBlock &BB : F) { Changed |= SimplifyInstructionsInBlock(&BB); - for (Instruction &I : BB) { - IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); + for (Instruction &I : BB) { + IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); if (II && II->getIntrinsicID() == Intrinsic::masked_gather && isa<FixedVectorType>(II->getType())) { - Gathers.push_back(II); + Gathers.push_back(II); Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI); } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter && isa<FixedVectorType>(II->getArgOperand(0)->getType())) { - Scatters.push_back(II); + Scatters.push_back(II); Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI); - } - } - } - for (unsigned i = 0; i < Gathers.size(); i++) { - IntrinsicInst *I = Gathers[i]; - Value *L = lowerGather(I); - if (L == nullptr) - continue; - - // Get rid of any now dead instructions - SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent()); - Changed = true; - } - - for (unsigned i = 0; i < Scatters.size(); i++) { - IntrinsicInst *I = Scatters[i]; - Value *S = lowerScatter(I); - if (S == nullptr) - continue; - - // Get rid of any now dead instructions - SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent()); - Changed = true; - } - return Changed; -} + } + } + } + for (unsigned i = 0; i < Gathers.size(); i++) { + IntrinsicInst *I = Gathers[i]; + Value *L = lowerGather(I); + if (L == nullptr) + continue; + + // Get rid of any now dead instructions + SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent()); + Changed = true; + } + + for (unsigned i = 0; i < Scatters.size(); i++) { + IntrinsicInst *I = Scatters[i]; + Value *S = lowerScatter(I);
+ if (S == nullptr) + continue; + + // Get rid of any now dead instructions + SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent()); + Changed = true; + } + return Changed; +}