author | shadchin <shadchin@yandex-team.ru> | 2022-02-10 16:44:39 +0300
committer | Daniil Cherednik <dcherednik@yandex-team.ru> | 2022-02-10 16:44:39 +0300
commit | e9656aae26e0358d5378e5b63dcac5c8dbe0e4d0 (patch)
tree | 64175d5cadab313b3e7039ebaa06c5bc3295e274 /contrib/libs/llvm12/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp
parent | 2598ef1d0aee359b4b6d5fdd1758916d5907d04f (diff)
download | ydb-e9656aae26e0358d5378e5b63dcac5c8dbe0e4d0.tar.gz
Restoring authorship annotation for <shadchin@yandex-team.ru>. Commit 2 of 2.
Diffstat (limited to 'contrib/libs/llvm12/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp')
-rw-r--r-- | contrib/libs/llvm12/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp | 1896
1 file changed, 948 insertions, 948 deletions
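For context while reading the diff below: for non-constant masks, the pass bitcasts the `<N x i1>` mask to an `iN` integer and tests one bit per lane (a "scalar bit test"), which the comments in the file note gives better results on X86 at least. The following is a condensed sketch of that predicate construction, using the same IRBuilder calls that appear in the file; the helper name `buildLanePredicate` is illustrative and not part of the source.

```cpp
// Minimal sketch (not part of the diff) of the per-lane predicate the pass
// builds for non-constant masks. Names are illustrative.
#include "llvm/ADT/APInt.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"

using namespace llvm;

// Returns an i1 value that is true iff lane `Idx` of `Mask` is set.
// `Mask` is assumed to be a fixed-width <VectorWidth x i1> value.
static Value *buildLanePredicate(IRBuilder<> &Builder, Value *Mask,
                                 unsigned VectorWidth, unsigned Idx) {
  // For a v1i1 mask there is nothing to gain; just extract the single bit.
  if (VectorWidth == 1)
    return Builder.CreateExtractElement(Mask, Idx);

  // Bitcast <N x i1> -> iN. The pass hoists this cast out of the per-lane
  // loop and reuses the scalar mask for every lane; it is inlined here to
  // keep the sketch self-contained.
  Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
  Value *SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");

  // %bit  = and iN %scalar_mask, (1 << Idx)
  // %pred = icmp ne iN %bit, 0
  Value *Bit = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
  return Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Bit),
                              Builder.getIntN(VectorWidth, 0));
}
```

The resulting i1 predicate is then used as the branch condition for the `cond.load` / `cond.store` blocks that the per-lane loop in the file splits off for each element.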
diff --git a/contrib/libs/llvm12/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/contrib/libs/llvm12/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp index c8da464a3b..afa2d1bc79 100644 --- a/contrib/libs/llvm12/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp +++ b/contrib/libs/llvm12/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp @@ -1,948 +1,948 @@ -//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===// -// instrinsics -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This pass replaces masked memory intrinsics - when unsupported by the target -// - with a chain of basic blocks, that deal with the elements one-by-one if the -// appropriate mask bit is set. -// -//===----------------------------------------------------------------------===// - -#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h" -#include "llvm/ADT/Twine.h" -#include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constant.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstrTypes.h" -#include "llvm/IR/Instruction.h" -#include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Intrinsics.h" -#include "llvm/IR/Type.h" -#include "llvm/IR/Value.h" -#include "llvm/InitializePasses.h" -#include "llvm/Pass.h" -#include "llvm/Support/Casting.h" -#include "llvm/Transforms/Scalar.h" -#include <algorithm> -#include <cassert> - -using namespace llvm; - -#define DEBUG_TYPE "scalarize-masked-mem-intrin" - -namespace { - -class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass { -public: - static char ID; // Pass identification, replacement for typeid - - explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) { - initializeScalarizeMaskedMemIntrinLegacyPassPass( - *PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - StringRef getPassName() const override { - return "Scalarize Masked Memory Intrinsics"; - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetTransformInfoWrapperPass>(); - } -}; - -} // end anonymous namespace - -static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, - const TargetTransformInfo &TTI, const DataLayout &DL); -static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, - const TargetTransformInfo &TTI, - const DataLayout &DL); - -char ScalarizeMaskedMemIntrinLegacyPass::ID = 0; - -INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE, - "Scalarize unsupported masked memory intrinsics", false, - false) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE, - "Scalarize unsupported masked memory intrinsics", false, - false) - -FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() { - return new ScalarizeMaskedMemIntrinLegacyPass(); -} - -static bool isConstantIntVector(Value *Mask) { - Constant *C = dyn_cast<Constant>(Mask); - if (!C) - return false; - - unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements(); - for (unsigned i = 0; i != NumElts; ++i) { - Constant *CElt = C->getAggregateElement(i); - if (!CElt || 
!isa<ConstantInt>(CElt)) - return false; - } - - return true; -} - -// Translate a masked load intrinsic like -// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align, -// <16 x i1> %mask, <16 x i32> %passthru) -// to a chain of basic blocks, with loading element one-by-one if -// the appropriate mask bit is set -// -// %1 = bitcast i8* %addr to i32* -// %2 = extractelement <16 x i1> %mask, i32 0 -// br i1 %2, label %cond.load, label %else -// -// cond.load: ; preds = %0 -// %3 = getelementptr i32* %1, i32 0 -// %4 = load i32* %3 -// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0 -// br label %else -// -// else: ; preds = %0, %cond.load -// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ] -// %6 = extractelement <16 x i1> %mask, i32 1 -// br i1 %6, label %cond.load1, label %else2 -// -// cond.load1: ; preds = %else -// %7 = getelementptr i32* %1, i32 1 -// %8 = load i32* %7 -// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1 -// br label %else2 -// -// else2: ; preds = %else, %cond.load1 -// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ] -// %10 = extractelement <16 x i1> %mask, i32 2 -// br i1 %10, label %cond.load4, label %else5 -// -static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) { - Value *Ptr = CI->getArgOperand(0); - Value *Alignment = CI->getArgOperand(1); - Value *Mask = CI->getArgOperand(2); - Value *Src0 = CI->getArgOperand(3); - - const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue(); - VectorType *VecType = cast<FixedVectorType>(CI->getType()); - - Type *EltTy = VecType->getElementType(); - - IRBuilder<> Builder(CI->getContext()); - Instruction *InsertPt = CI; - BasicBlock *IfBlock = CI->getParent(); - - Builder.SetInsertPoint(InsertPt); - Builder.SetCurrentDebugLocation(CI->getDebugLoc()); - - // Short-cut if the mask is all-true. - if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) { - Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal); - CI->replaceAllUsesWith(NewI); - CI->eraseFromParent(); - return; - } - - // Adjust alignment for the scalar instruction. - const Align AdjustedAlignVal = - commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8); - // Bitcast %addr from i8* to EltTy* - Type *NewPtrType = - EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); - Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); - unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements(); - - // The result vector - Value *VResult = Src0; - - if (isConstantIntVector(Mask)) { - for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { - if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) - continue; - Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); - LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal); - VResult = Builder.CreateInsertElement(VResult, Load, Idx); - } - CI->replaceAllUsesWith(VResult); - CI->eraseFromParent(); - return; - } - - // If the mask is not v1i1, use scalar bit test operations. This generates - // better results on X86 at least. 
- Value *SclrMask; - if (VectorWidth != 1) { - Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); - SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); - } - - for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { - // Fill the "else" block, created in the previous iteration - // - // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ] - // %mask_1 = and i16 %scalar_mask, i32 1 << Idx - // %cond = icmp ne i16 %mask_1, 0 - // br i1 %mask_1, label %cond.load, label %else - // - Value *Predicate; - if (VectorWidth != 1) { - Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); - Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), - Builder.getIntN(VectorWidth, 0)); - } else { - Predicate = Builder.CreateExtractElement(Mask, Idx); - } - - // Create "cond" block - // - // %EltAddr = getelementptr i32* %1, i32 0 - // %Elt = load i32* %EltAddr - // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx - // - BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), - "cond.load"); - Builder.SetInsertPoint(InsertPt); - - Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); - LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal); - Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx); - - // Create "else" block, fill it in the next iteration - BasicBlock *NewIfBlock = - CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); - Builder.SetInsertPoint(InsertPt); - Instruction *OldBr = IfBlock->getTerminator(); - BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); - OldBr->eraseFromParent(); - BasicBlock *PrevIfBlock = IfBlock; - IfBlock = NewIfBlock; - - // Create the phi to join the new and previous value. - PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else"); - Phi->addIncoming(NewVResult, CondBlock); - Phi->addIncoming(VResult, PrevIfBlock); - VResult = Phi; - } - - CI->replaceAllUsesWith(VResult); - CI->eraseFromParent(); - - ModifiedDT = true; -} - -// Translate a masked store intrinsic, like -// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align, -// <16 x i1> %mask) -// to a chain of basic blocks, that stores element one-by-one if -// the appropriate mask bit is set -// -// %1 = bitcast i8* %addr to i32* -// %2 = extractelement <16 x i1> %mask, i32 0 -// br i1 %2, label %cond.store, label %else -// -// cond.store: ; preds = %0 -// %3 = extractelement <16 x i32> %val, i32 0 -// %4 = getelementptr i32* %1, i32 0 -// store i32 %3, i32* %4 -// br label %else -// -// else: ; preds = %0, %cond.store -// %5 = extractelement <16 x i1> %mask, i32 1 -// br i1 %5, label %cond.store1, label %else2 -// -// cond.store1: ; preds = %else -// %6 = extractelement <16 x i32> %val, i32 1 -// %7 = getelementptr i32* %1, i32 1 -// store i32 %6, i32* %7 -// br label %else2 -// . . . -static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) { - Value *Src = CI->getArgOperand(0); - Value *Ptr = CI->getArgOperand(1); - Value *Alignment = CI->getArgOperand(2); - Value *Mask = CI->getArgOperand(3); - - const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue(); - auto *VecType = cast<VectorType>(Src->getType()); - - Type *EltTy = VecType->getElementType(); - - IRBuilder<> Builder(CI->getContext()); - Instruction *InsertPt = CI; - BasicBlock *IfBlock = CI->getParent(); - Builder.SetInsertPoint(InsertPt); - Builder.SetCurrentDebugLocation(CI->getDebugLoc()); - - // Short-cut if the mask is all-true. 
- if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) { - Builder.CreateAlignedStore(Src, Ptr, AlignVal); - CI->eraseFromParent(); - return; - } - - // Adjust alignment for the scalar instruction. - const Align AdjustedAlignVal = - commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8); - // Bitcast %addr from i8* to EltTy* - Type *NewPtrType = - EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); - Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); - unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements(); - - if (isConstantIntVector(Mask)) { - for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { - if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) - continue; - Value *OneElt = Builder.CreateExtractElement(Src, Idx); - Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); - Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal); - } - CI->eraseFromParent(); - return; - } - - // If the mask is not v1i1, use scalar bit test operations. This generates - // better results on X86 at least. - Value *SclrMask; - if (VectorWidth != 1) { - Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); - SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); - } - - for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { - // Fill the "else" block, created in the previous iteration - // - // %mask_1 = and i16 %scalar_mask, i32 1 << Idx - // %cond = icmp ne i16 %mask_1, 0 - // br i1 %mask_1, label %cond.store, label %else - // - Value *Predicate; - if (VectorWidth != 1) { - Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); - Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), - Builder.getIntN(VectorWidth, 0)); - } else { - Predicate = Builder.CreateExtractElement(Mask, Idx); - } - - // Create "cond" block - // - // %OneElt = extractelement <16 x i32> %Src, i32 Idx - // %EltAddr = getelementptr i32* %1, i32 0 - // %store i32 %OneElt, i32* %EltAddr - // - BasicBlock *CondBlock = - IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store"); - Builder.SetInsertPoint(InsertPt); - - Value *OneElt = Builder.CreateExtractElement(Src, Idx); - Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); - Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal); - - // Create "else" block, fill it in the next iteration - BasicBlock *NewIfBlock = - CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); - Builder.SetInsertPoint(InsertPt); - Instruction *OldBr = IfBlock->getTerminator(); - BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); - OldBr->eraseFromParent(); - IfBlock = NewIfBlock; - } - CI->eraseFromParent(); - - ModifiedDT = true; -} - -// Translate a masked gather intrinsic like -// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4, -// <16 x i1> %Mask, <16 x i32> %Src) -// to a chain of basic blocks, with loading element one-by-one if -// the appropriate mask bit is set -// -// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind -// %Mask0 = extractelement <16 x i1> %Mask, i32 0 -// br i1 %Mask0, label %cond.load, label %else -// -// cond.load: -// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 -// %Load0 = load i32, i32* %Ptr0, align 4 -// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0 -// br label %else -// -// else: -// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0] -// %Mask1 = extractelement <16 x i1> %Mask, i32 1 -// br i1 %Mask1, label %cond.load1, label %else2 -// -// cond.load1: -// 
%Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 -// %Load1 = load i32, i32* %Ptr1, align 4 -// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1 -// br label %else2 -// . . . -// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src -// ret <16 x i32> %Result -static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) { - Value *Ptrs = CI->getArgOperand(0); - Value *Alignment = CI->getArgOperand(1); - Value *Mask = CI->getArgOperand(2); - Value *Src0 = CI->getArgOperand(3); - - auto *VecType = cast<FixedVectorType>(CI->getType()); - Type *EltTy = VecType->getElementType(); - - IRBuilder<> Builder(CI->getContext()); - Instruction *InsertPt = CI; - BasicBlock *IfBlock = CI->getParent(); - Builder.SetInsertPoint(InsertPt); - MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue(); - - Builder.SetCurrentDebugLocation(CI->getDebugLoc()); - - // The result vector - Value *VResult = Src0; - unsigned VectorWidth = VecType->getNumElements(); - - // Shorten the way if the mask is a vector of constants. - if (isConstantIntVector(Mask)) { - for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { - if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) - continue; - Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); - LoadInst *Load = - Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx)); - VResult = - Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx)); - } - CI->replaceAllUsesWith(VResult); - CI->eraseFromParent(); - return; - } - - // If the mask is not v1i1, use scalar bit test operations. This generates - // better results on X86 at least. - Value *SclrMask; - if (VectorWidth != 1) { - Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); - SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); - } - - for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { - // Fill the "else" block, created in the previous iteration - // - // %Mask1 = and i16 %scalar_mask, i32 1 << Idx - // %cond = icmp ne i16 %mask_1, 0 - // br i1 %Mask1, label %cond.load, label %else - // - - Value *Predicate; - if (VectorWidth != 1) { - Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); - Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), - Builder.getIntN(VectorWidth, 0)); - } else { - Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); - } - - // Create "cond" block - // - // %EltAddr = getelementptr i32* %1, i32 0 - // %Elt = load i32* %EltAddr - // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx - // - BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load"); - Builder.SetInsertPoint(InsertPt); - - Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); - LoadInst *Load = - Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx)); - Value *NewVResult = - Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx)); - - // Create "else" block, fill it in the next iteration - BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else"); - Builder.SetInsertPoint(InsertPt); - Instruction *OldBr = IfBlock->getTerminator(); - BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); - OldBr->eraseFromParent(); - BasicBlock *PrevIfBlock = IfBlock; - IfBlock = NewIfBlock; - - PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else"); - Phi->addIncoming(NewVResult, CondBlock); - Phi->addIncoming(VResult, PrevIfBlock); - VResult = Phi; - } - - 
CI->replaceAllUsesWith(VResult); - CI->eraseFromParent(); - - ModifiedDT = true; -} - -// Translate a masked scatter intrinsic, like -// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4, -// <16 x i1> %Mask) -// to a chain of basic blocks, that stores element one-by-one if -// the appropriate mask bit is set. -// -// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind -// %Mask0 = extractelement <16 x i1> %Mask, i32 0 -// br i1 %Mask0, label %cond.store, label %else -// -// cond.store: -// %Elt0 = extractelement <16 x i32> %Src, i32 0 -// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 -// store i32 %Elt0, i32* %Ptr0, align 4 -// br label %else -// -// else: -// %Mask1 = extractelement <16 x i1> %Mask, i32 1 -// br i1 %Mask1, label %cond.store1, label %else2 -// -// cond.store1: -// %Elt1 = extractelement <16 x i32> %Src, i32 1 -// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 -// store i32 %Elt1, i32* %Ptr1, align 4 -// br label %else2 -// . . . -static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) { - Value *Src = CI->getArgOperand(0); - Value *Ptrs = CI->getArgOperand(1); - Value *Alignment = CI->getArgOperand(2); - Value *Mask = CI->getArgOperand(3); - - auto *SrcFVTy = cast<FixedVectorType>(Src->getType()); - - assert( - isa<VectorType>(Ptrs->getType()) && - isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) && - "Vector of pointers is expected in masked scatter intrinsic"); - - IRBuilder<> Builder(CI->getContext()); - Instruction *InsertPt = CI; - BasicBlock *IfBlock = CI->getParent(); - Builder.SetInsertPoint(InsertPt); - Builder.SetCurrentDebugLocation(CI->getDebugLoc()); - - MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue(); - unsigned VectorWidth = SrcFVTy->getNumElements(); - - // Shorten the way if the mask is a vector of constants. - if (isConstantIntVector(Mask)) { - for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { - if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) - continue; - Value *OneElt = - Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); - Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); - Builder.CreateAlignedStore(OneElt, Ptr, AlignVal); - } - CI->eraseFromParent(); - return; - } - - // If the mask is not v1i1, use scalar bit test operations. This generates - // better results on X86 at least. 
- Value *SclrMask; - if (VectorWidth != 1) { - Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); - SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); - } - - for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { - // Fill the "else" block, created in the previous iteration - // - // %Mask1 = and i16 %scalar_mask, i32 1 << Idx - // %cond = icmp ne i16 %mask_1, 0 - // br i1 %Mask1, label %cond.store, label %else - // - Value *Predicate; - if (VectorWidth != 1) { - Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); - Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), - Builder.getIntN(VectorWidth, 0)); - } else { - Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); - } - - // Create "cond" block - // - // %Elt1 = extractelement <16 x i32> %Src, i32 1 - // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 - // %store i32 %Elt1, i32* %Ptr1 - // - BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store"); - Builder.SetInsertPoint(InsertPt); - - Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); - Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); - Builder.CreateAlignedStore(OneElt, Ptr, AlignVal); - - // Create "else" block, fill it in the next iteration - BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else"); - Builder.SetInsertPoint(InsertPt); - Instruction *OldBr = IfBlock->getTerminator(); - BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); - OldBr->eraseFromParent(); - IfBlock = NewIfBlock; - } - CI->eraseFromParent(); - - ModifiedDT = true; -} - -static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) { - Value *Ptr = CI->getArgOperand(0); - Value *Mask = CI->getArgOperand(1); - Value *PassThru = CI->getArgOperand(2); - - auto *VecType = cast<FixedVectorType>(CI->getType()); - - Type *EltTy = VecType->getElementType(); - - IRBuilder<> Builder(CI->getContext()); - Instruction *InsertPt = CI; - BasicBlock *IfBlock = CI->getParent(); - - Builder.SetInsertPoint(InsertPt); - Builder.SetCurrentDebugLocation(CI->getDebugLoc()); - - unsigned VectorWidth = VecType->getNumElements(); - - // The result vector - Value *VResult = PassThru; - - // Shorten the way if the mask is a vector of constants. - // Create a build_vector pattern, with loads/undefs as necessary and then - // shuffle blend with the pass through value. - if (isConstantIntVector(Mask)) { - unsigned MemIndex = 0; - VResult = UndefValue::get(VecType); - SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem); - for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { - Value *InsertElt; - if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) { - InsertElt = UndefValue::get(EltTy); - ShuffleMask[Idx] = Idx + VectorWidth; - } else { - Value *NewPtr = - Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex); - InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1), - "Load" + Twine(Idx)); - ShuffleMask[Idx] = Idx; - ++MemIndex; - } - VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx, - "Res" + Twine(Idx)); - } - VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask); - CI->replaceAllUsesWith(VResult); - CI->eraseFromParent(); - return; - } - - // If the mask is not v1i1, use scalar bit test operations. This generates - // better results on X86 at least. 
- Value *SclrMask; - if (VectorWidth != 1) { - Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); - SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); - } - - for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { - // Fill the "else" block, created in the previous iteration - // - // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ] - // %mask_1 = extractelement <16 x i1> %mask, i32 Idx - // br i1 %mask_1, label %cond.load, label %else - // - - Value *Predicate; - if (VectorWidth != 1) { - Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); - Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), - Builder.getIntN(VectorWidth, 0)); - } else { - Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); - } - - // Create "cond" block - // - // %EltAddr = getelementptr i32* %1, i32 0 - // %Elt = load i32* %EltAddr - // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx - // - BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), - "cond.load"); - Builder.SetInsertPoint(InsertPt); - - LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1)); - Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx); - - // Move the pointer if there are more blocks to come. - Value *NewPtr; - if ((Idx + 1) != VectorWidth) - NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1); - - // Create "else" block, fill it in the next iteration - BasicBlock *NewIfBlock = - CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); - Builder.SetInsertPoint(InsertPt); - Instruction *OldBr = IfBlock->getTerminator(); - BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); - OldBr->eraseFromParent(); - BasicBlock *PrevIfBlock = IfBlock; - IfBlock = NewIfBlock; - - // Create the phi to join the new and previous value. - PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else"); - ResultPhi->addIncoming(NewVResult, CondBlock); - ResultPhi->addIncoming(VResult, PrevIfBlock); - VResult = ResultPhi; - - // Add a PHI for the pointer if this isn't the last iteration. - if ((Idx + 1) != VectorWidth) { - PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else"); - PtrPhi->addIncoming(NewPtr, CondBlock); - PtrPhi->addIncoming(Ptr, PrevIfBlock); - Ptr = PtrPhi; - } - } - - CI->replaceAllUsesWith(VResult); - CI->eraseFromParent(); - - ModifiedDT = true; -} - -static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) { - Value *Src = CI->getArgOperand(0); - Value *Ptr = CI->getArgOperand(1); - Value *Mask = CI->getArgOperand(2); - - auto *VecType = cast<FixedVectorType>(Src->getType()); - - IRBuilder<> Builder(CI->getContext()); - Instruction *InsertPt = CI; - BasicBlock *IfBlock = CI->getParent(); - - Builder.SetInsertPoint(InsertPt); - Builder.SetCurrentDebugLocation(CI->getDebugLoc()); - - Type *EltTy = VecType->getElementType(); - - unsigned VectorWidth = VecType->getNumElements(); - - // Shorten the way if the mask is a vector of constants. 
- if (isConstantIntVector(Mask)) { - unsigned MemIndex = 0; - for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { - if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) - continue; - Value *OneElt = - Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); - Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex); - Builder.CreateAlignedStore(OneElt, NewPtr, Align(1)); - ++MemIndex; - } - CI->eraseFromParent(); - return; - } - - // If the mask is not v1i1, use scalar bit test operations. This generates - // better results on X86 at least. - Value *SclrMask; - if (VectorWidth != 1) { - Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); - SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); - } - - for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { - // Fill the "else" block, created in the previous iteration - // - // %mask_1 = extractelement <16 x i1> %mask, i32 Idx - // br i1 %mask_1, label %cond.store, label %else - // - Value *Predicate; - if (VectorWidth != 1) { - Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); - Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), - Builder.getIntN(VectorWidth, 0)); - } else { - Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); - } - - // Create "cond" block - // - // %OneElt = extractelement <16 x i32> %Src, i32 Idx - // %EltAddr = getelementptr i32* %1, i32 0 - // %store i32 %OneElt, i32* %EltAddr - // - BasicBlock *CondBlock = - IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store"); - Builder.SetInsertPoint(InsertPt); - - Value *OneElt = Builder.CreateExtractElement(Src, Idx); - Builder.CreateAlignedStore(OneElt, Ptr, Align(1)); - - // Move the pointer if there are more blocks to come. - Value *NewPtr; - if ((Idx + 1) != VectorWidth) - NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1); - - // Create "else" block, fill it in the next iteration - BasicBlock *NewIfBlock = - CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); - Builder.SetInsertPoint(InsertPt); - Instruction *OldBr = IfBlock->getTerminator(); - BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); - OldBr->eraseFromParent(); - BasicBlock *PrevIfBlock = IfBlock; - IfBlock = NewIfBlock; - - // Add a PHI for the pointer if this isn't the last iteration. 
- if ((Idx + 1) != VectorWidth) { - PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else"); - PtrPhi->addIncoming(NewPtr, CondBlock); - PtrPhi->addIncoming(Ptr, PrevIfBlock); - Ptr = PtrPhi; - } - } - CI->eraseFromParent(); - - ModifiedDT = true; -} - -static bool runImpl(Function &F, const TargetTransformInfo &TTI) { - bool EverMadeChange = false; - bool MadeChange = true; - auto &DL = F.getParent()->getDataLayout(); - while (MadeChange) { - MadeChange = false; - for (Function::iterator I = F.begin(); I != F.end();) { - BasicBlock *BB = &*I++; - bool ModifiedDTOnIteration = false; - MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration, TTI, DL); - - // Restart BB iteration if the dominator tree of the Function was changed - if (ModifiedDTOnIteration) - break; - } - - EverMadeChange |= MadeChange; - } - return EverMadeChange; -} - -bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) { - auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - return runImpl(F, TTI); -} - -PreservedAnalyses -ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) { - auto &TTI = AM.getResult<TargetIRAnalysis>(F); - if (!runImpl(F, TTI)) - return PreservedAnalyses::all(); - PreservedAnalyses PA; - PA.preserve<TargetIRAnalysis>(); - return PA; -} - -static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, - const TargetTransformInfo &TTI, - const DataLayout &DL) { - bool MadeChange = false; - - BasicBlock::iterator CurInstIterator = BB.begin(); - while (CurInstIterator != BB.end()) { - if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++)) - MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL); - if (ModifiedDT) - return true; - } - - return MadeChange; -} - -static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, - const TargetTransformInfo &TTI, - const DataLayout &DL) { - IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI); - if (II) { - // The scalarization code below does not work for scalable vectors. 
- if (isa<ScalableVectorType>(II->getType()) || - any_of(II->arg_operands(), - [](Value *V) { return isa<ScalableVectorType>(V->getType()); })) - return false; - - switch (II->getIntrinsicID()) { - default: - break; - case Intrinsic::masked_load: - // Scalarize unsupported vector masked load - if (TTI.isLegalMaskedLoad( - CI->getType(), - cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue())) - return false; - scalarizeMaskedLoad(CI, ModifiedDT); - return true; - case Intrinsic::masked_store: - if (TTI.isLegalMaskedStore( - CI->getArgOperand(0)->getType(), - cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue())) - return false; - scalarizeMaskedStore(CI, ModifiedDT); - return true; - case Intrinsic::masked_gather: { - unsigned AlignmentInt = - cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue(); - Type *LoadTy = CI->getType(); - Align Alignment = - DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), LoadTy); - if (TTI.isLegalMaskedGather(LoadTy, Alignment)) - return false; - scalarizeMaskedGather(CI, ModifiedDT); - return true; - } - case Intrinsic::masked_scatter: { - unsigned AlignmentInt = - cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue(); - Type *StoreTy = CI->getArgOperand(0)->getType(); - Align Alignment = - DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), StoreTy); - if (TTI.isLegalMaskedScatter(StoreTy, Alignment)) - return false; - scalarizeMaskedScatter(CI, ModifiedDT); - return true; - } - case Intrinsic::masked_expandload: - if (TTI.isLegalMaskedExpandLoad(CI->getType())) - return false; - scalarizeMaskedExpandLoad(CI, ModifiedDT); - return true; - case Intrinsic::masked_compressstore: - if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType())) - return false; - scalarizeMaskedCompressStore(CI, ModifiedDT); - return true; - } - } - - return false; -} +//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===// +// instrinsics +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This pass replaces masked memory intrinsics - when unsupported by the target +// - with a chain of basic blocks, that deal with the elements one-by-one if the +// appropriate mask bit is set. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Transforms/Scalar.h" +#include <algorithm> +#include <cassert> + +using namespace llvm; + +#define DEBUG_TYPE "scalarize-masked-mem-intrin" + +namespace { + +class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + + explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) { + initializeScalarizeMaskedMemIntrinLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; + + StringRef getPassName() const override { + return "Scalarize Masked Memory Intrinsics"; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetTransformInfoWrapperPass>(); + } +}; + +} // end anonymous namespace + +static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, + const TargetTransformInfo &TTI, const DataLayout &DL); +static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, + const TargetTransformInfo &TTI, + const DataLayout &DL); + +char ScalarizeMaskedMemIntrinLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE, + "Scalarize unsupported masked memory intrinsics", false, + false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE, + "Scalarize unsupported masked memory intrinsics", false, + false) + +FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() { + return new ScalarizeMaskedMemIntrinLegacyPass(); +} + +static bool isConstantIntVector(Value *Mask) { + Constant *C = dyn_cast<Constant>(Mask); + if (!C) + return false; + + unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *CElt = C->getAggregateElement(i); + if (!CElt || !isa<ConstantInt>(CElt)) + return false; + } + + return true; +} + +// Translate a masked load intrinsic like +// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align, +// <16 x i1> %mask, <16 x i32> %passthru) +// to a chain of basic blocks, with loading element one-by-one if +// the appropriate mask bit is set +// +// %1 = bitcast i8* %addr to i32* +// %2 = extractelement <16 x i1> %mask, i32 0 +// br i1 %2, label %cond.load, label %else +// +// cond.load: ; preds = %0 +// %3 = getelementptr i32* %1, i32 0 +// %4 = load i32* %3 +// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0 +// br label %else +// +// else: ; preds = %0, %cond.load +// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ] +// %6 = extractelement <16 x i1> %mask, i32 1 +// br i1 %6, label %cond.load1, label %else2 +// +// cond.load1: ; preds = %else +// %7 = getelementptr i32* %1, i32 1 +// %8 = load i32* %7 +// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1 
+// br label %else2 +// +// else2: ; preds = %else, %cond.load1 +// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ] +// %10 = extractelement <16 x i1> %mask, i32 2 +// br i1 %10, label %cond.load4, label %else5 +// +static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) { + Value *Ptr = CI->getArgOperand(0); + Value *Alignment = CI->getArgOperand(1); + Value *Mask = CI->getArgOperand(2); + Value *Src0 = CI->getArgOperand(3); + + const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue(); + VectorType *VecType = cast<FixedVectorType>(CI->getType()); + + Type *EltTy = VecType->getElementType(); + + IRBuilder<> Builder(CI->getContext()); + Instruction *InsertPt = CI; + BasicBlock *IfBlock = CI->getParent(); + + Builder.SetInsertPoint(InsertPt); + Builder.SetCurrentDebugLocation(CI->getDebugLoc()); + + // Short-cut if the mask is all-true. + if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) { + Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal); + CI->replaceAllUsesWith(NewI); + CI->eraseFromParent(); + return; + } + + // Adjust alignment for the scalar instruction. + const Align AdjustedAlignVal = + commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8); + // Bitcast %addr from i8* to EltTy* + Type *NewPtrType = + EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); + Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); + unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements(); + + // The result vector + Value *VResult = Src0; + + if (isConstantIntVector(Mask)) { + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) + continue; + Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); + LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal); + VResult = Builder.CreateInsertElement(VResult, Load, Idx); + } + CI->replaceAllUsesWith(VResult); + CI->eraseFromParent(); + return; + } + + // If the mask is not v1i1, use scalar bit test operations. This generates + // better results on X86 at least. 
+ Value *SclrMask; + if (VectorWidth != 1) { + Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); + SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); + } + + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + // Fill the "else" block, created in the previous iteration + // + // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ] + // %mask_1 = and i16 %scalar_mask, i32 1 << Idx + // %cond = icmp ne i16 %mask_1, 0 + // br i1 %mask_1, label %cond.load, label %else + // + Value *Predicate; + if (VectorWidth != 1) { + Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); + Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), + Builder.getIntN(VectorWidth, 0)); + } else { + Predicate = Builder.CreateExtractElement(Mask, Idx); + } + + // Create "cond" block + // + // %EltAddr = getelementptr i32* %1, i32 0 + // %Elt = load i32* %EltAddr + // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx + // + BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), + "cond.load"); + Builder.SetInsertPoint(InsertPt); + + Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); + LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal); + Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx); + + // Create "else" block, fill it in the next iteration + BasicBlock *NewIfBlock = + CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); + Builder.SetInsertPoint(InsertPt); + Instruction *OldBr = IfBlock->getTerminator(); + BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); + OldBr->eraseFromParent(); + BasicBlock *PrevIfBlock = IfBlock; + IfBlock = NewIfBlock; + + // Create the phi to join the new and previous value. + PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else"); + Phi->addIncoming(NewVResult, CondBlock); + Phi->addIncoming(VResult, PrevIfBlock); + VResult = Phi; + } + + CI->replaceAllUsesWith(VResult); + CI->eraseFromParent(); + + ModifiedDT = true; +} + +// Translate a masked store intrinsic, like +// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align, +// <16 x i1> %mask) +// to a chain of basic blocks, that stores element one-by-one if +// the appropriate mask bit is set +// +// %1 = bitcast i8* %addr to i32* +// %2 = extractelement <16 x i1> %mask, i32 0 +// br i1 %2, label %cond.store, label %else +// +// cond.store: ; preds = %0 +// %3 = extractelement <16 x i32> %val, i32 0 +// %4 = getelementptr i32* %1, i32 0 +// store i32 %3, i32* %4 +// br label %else +// +// else: ; preds = %0, %cond.store +// %5 = extractelement <16 x i1> %mask, i32 1 +// br i1 %5, label %cond.store1, label %else2 +// +// cond.store1: ; preds = %else +// %6 = extractelement <16 x i32> %val, i32 1 +// %7 = getelementptr i32* %1, i32 1 +// store i32 %6, i32* %7 +// br label %else2 +// . . . +static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) { + Value *Src = CI->getArgOperand(0); + Value *Ptr = CI->getArgOperand(1); + Value *Alignment = CI->getArgOperand(2); + Value *Mask = CI->getArgOperand(3); + + const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue(); + auto *VecType = cast<VectorType>(Src->getType()); + + Type *EltTy = VecType->getElementType(); + + IRBuilder<> Builder(CI->getContext()); + Instruction *InsertPt = CI; + BasicBlock *IfBlock = CI->getParent(); + Builder.SetInsertPoint(InsertPt); + Builder.SetCurrentDebugLocation(CI->getDebugLoc()); + + // Short-cut if the mask is all-true. 
+ if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) { + Builder.CreateAlignedStore(Src, Ptr, AlignVal); + CI->eraseFromParent(); + return; + } + + // Adjust alignment for the scalar instruction. + const Align AdjustedAlignVal = + commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8); + // Bitcast %addr from i8* to EltTy* + Type *NewPtrType = + EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); + Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); + unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements(); + + if (isConstantIntVector(Mask)) { + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) + continue; + Value *OneElt = Builder.CreateExtractElement(Src, Idx); + Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); + Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal); + } + CI->eraseFromParent(); + return; + } + + // If the mask is not v1i1, use scalar bit test operations. This generates + // better results on X86 at least. + Value *SclrMask; + if (VectorWidth != 1) { + Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); + SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); + } + + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + // Fill the "else" block, created in the previous iteration + // + // %mask_1 = and i16 %scalar_mask, i32 1 << Idx + // %cond = icmp ne i16 %mask_1, 0 + // br i1 %mask_1, label %cond.store, label %else + // + Value *Predicate; + if (VectorWidth != 1) { + Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); + Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), + Builder.getIntN(VectorWidth, 0)); + } else { + Predicate = Builder.CreateExtractElement(Mask, Idx); + } + + // Create "cond" block + // + // %OneElt = extractelement <16 x i32> %Src, i32 Idx + // %EltAddr = getelementptr i32* %1, i32 0 + // %store i32 %OneElt, i32* %EltAddr + // + BasicBlock *CondBlock = + IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store"); + Builder.SetInsertPoint(InsertPt); + + Value *OneElt = Builder.CreateExtractElement(Src, Idx); + Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); + Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal); + + // Create "else" block, fill it in the next iteration + BasicBlock *NewIfBlock = + CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); + Builder.SetInsertPoint(InsertPt); + Instruction *OldBr = IfBlock->getTerminator(); + BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); + OldBr->eraseFromParent(); + IfBlock = NewIfBlock; + } + CI->eraseFromParent(); + + ModifiedDT = true; +} + +// Translate a masked gather intrinsic like +// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4, +// <16 x i1> %Mask, <16 x i32> %Src) +// to a chain of basic blocks, with loading element one-by-one if +// the appropriate mask bit is set +// +// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind +// %Mask0 = extractelement <16 x i1> %Mask, i32 0 +// br i1 %Mask0, label %cond.load, label %else +// +// cond.load: +// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 +// %Load0 = load i32, i32* %Ptr0, align 4 +// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0 +// br label %else +// +// else: +// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0] +// %Mask1 = extractelement <16 x i1> %Mask, i32 1 +// br i1 %Mask1, label %cond.load1, label %else2 +// +// cond.load1: +// 
%Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 +// %Load1 = load i32, i32* %Ptr1, align 4 +// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1 +// br label %else2 +// . . . +// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src +// ret <16 x i32> %Result +static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) { + Value *Ptrs = CI->getArgOperand(0); + Value *Alignment = CI->getArgOperand(1); + Value *Mask = CI->getArgOperand(2); + Value *Src0 = CI->getArgOperand(3); + + auto *VecType = cast<FixedVectorType>(CI->getType()); + Type *EltTy = VecType->getElementType(); + + IRBuilder<> Builder(CI->getContext()); + Instruction *InsertPt = CI; + BasicBlock *IfBlock = CI->getParent(); + Builder.SetInsertPoint(InsertPt); + MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue(); + + Builder.SetCurrentDebugLocation(CI->getDebugLoc()); + + // The result vector + Value *VResult = Src0; + unsigned VectorWidth = VecType->getNumElements(); + + // Shorten the way if the mask is a vector of constants. + if (isConstantIntVector(Mask)) { + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) + continue; + Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); + LoadInst *Load = + Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx)); + VResult = + Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx)); + } + CI->replaceAllUsesWith(VResult); + CI->eraseFromParent(); + return; + } + + // If the mask is not v1i1, use scalar bit test operations. This generates + // better results on X86 at least. + Value *SclrMask; + if (VectorWidth != 1) { + Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); + SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); + } + + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + // Fill the "else" block, created in the previous iteration + // + // %Mask1 = and i16 %scalar_mask, i32 1 << Idx + // %cond = icmp ne i16 %mask_1, 0 + // br i1 %Mask1, label %cond.load, label %else + // + + Value *Predicate; + if (VectorWidth != 1) { + Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); + Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), + Builder.getIntN(VectorWidth, 0)); + } else { + Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); + } + + // Create "cond" block + // + // %EltAddr = getelementptr i32* %1, i32 0 + // %Elt = load i32* %EltAddr + // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx + // + BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load"); + Builder.SetInsertPoint(InsertPt); + + Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); + LoadInst *Load = + Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx)); + Value *NewVResult = + Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx)); + + // Create "else" block, fill it in the next iteration + BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else"); + Builder.SetInsertPoint(InsertPt); + Instruction *OldBr = IfBlock->getTerminator(); + BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); + OldBr->eraseFromParent(); + BasicBlock *PrevIfBlock = IfBlock; + IfBlock = NewIfBlock; + + PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else"); + Phi->addIncoming(NewVResult, CondBlock); + Phi->addIncoming(VResult, PrevIfBlock); + VResult = Phi; + } + + 
CI->replaceAllUsesWith(VResult); + CI->eraseFromParent(); + + ModifiedDT = true; +} + +// Translate a masked scatter intrinsic, like +// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4, +// <16 x i1> %Mask) +// to a chain of basic blocks, that stores element one-by-one if +// the appropriate mask bit is set. +// +// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind +// %Mask0 = extractelement <16 x i1> %Mask, i32 0 +// br i1 %Mask0, label %cond.store, label %else +// +// cond.store: +// %Elt0 = extractelement <16 x i32> %Src, i32 0 +// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 +// store i32 %Elt0, i32* %Ptr0, align 4 +// br label %else +// +// else: +// %Mask1 = extractelement <16 x i1> %Mask, i32 1 +// br i1 %Mask1, label %cond.store1, label %else2 +// +// cond.store1: +// %Elt1 = extractelement <16 x i32> %Src, i32 1 +// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 +// store i32 %Elt1, i32* %Ptr1, align 4 +// br label %else2 +// . . . +static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) { + Value *Src = CI->getArgOperand(0); + Value *Ptrs = CI->getArgOperand(1); + Value *Alignment = CI->getArgOperand(2); + Value *Mask = CI->getArgOperand(3); + + auto *SrcFVTy = cast<FixedVectorType>(Src->getType()); + + assert( + isa<VectorType>(Ptrs->getType()) && + isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) && + "Vector of pointers is expected in masked scatter intrinsic"); + + IRBuilder<> Builder(CI->getContext()); + Instruction *InsertPt = CI; + BasicBlock *IfBlock = CI->getParent(); + Builder.SetInsertPoint(InsertPt); + Builder.SetCurrentDebugLocation(CI->getDebugLoc()); + + MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue(); + unsigned VectorWidth = SrcFVTy->getNumElements(); + + // Shorten the way if the mask is a vector of constants. + if (isConstantIntVector(Mask)) { + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) + continue; + Value *OneElt = + Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); + Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); + Builder.CreateAlignedStore(OneElt, Ptr, AlignVal); + } + CI->eraseFromParent(); + return; + } + + // If the mask is not v1i1, use scalar bit test operations. This generates + // better results on X86 at least. 
+ Value *SclrMask; + if (VectorWidth != 1) { + Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); + SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); + } + + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + // Fill the "else" block, created in the previous iteration + // + // %Mask1 = and i16 %scalar_mask, i32 1 << Idx + // %cond = icmp ne i16 %mask_1, 0 + // br i1 %Mask1, label %cond.store, label %else + // + Value *Predicate; + if (VectorWidth != 1) { + Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); + Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), + Builder.getIntN(VectorWidth, 0)); + } else { + Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); + } + + // Create "cond" block + // + // %Elt1 = extractelement <16 x i32> %Src, i32 1 + // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 + // %store i32 %Elt1, i32* %Ptr1 + // + BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store"); + Builder.SetInsertPoint(InsertPt); + + Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); + Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); + Builder.CreateAlignedStore(OneElt, Ptr, AlignVal); + + // Create "else" block, fill it in the next iteration + BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else"); + Builder.SetInsertPoint(InsertPt); + Instruction *OldBr = IfBlock->getTerminator(); + BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); + OldBr->eraseFromParent(); + IfBlock = NewIfBlock; + } + CI->eraseFromParent(); + + ModifiedDT = true; +} + +static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) { + Value *Ptr = CI->getArgOperand(0); + Value *Mask = CI->getArgOperand(1); + Value *PassThru = CI->getArgOperand(2); + + auto *VecType = cast<FixedVectorType>(CI->getType()); + + Type *EltTy = VecType->getElementType(); + + IRBuilder<> Builder(CI->getContext()); + Instruction *InsertPt = CI; + BasicBlock *IfBlock = CI->getParent(); + + Builder.SetInsertPoint(InsertPt); + Builder.SetCurrentDebugLocation(CI->getDebugLoc()); + + unsigned VectorWidth = VecType->getNumElements(); + + // The result vector + Value *VResult = PassThru; + + // Shorten the way if the mask is a vector of constants. + // Create a build_vector pattern, with loads/undefs as necessary and then + // shuffle blend with the pass through value. + if (isConstantIntVector(Mask)) { + unsigned MemIndex = 0; + VResult = UndefValue::get(VecType); + SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem); + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + Value *InsertElt; + if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) { + InsertElt = UndefValue::get(EltTy); + ShuffleMask[Idx] = Idx + VectorWidth; + } else { + Value *NewPtr = + Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex); + InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1), + "Load" + Twine(Idx)); + ShuffleMask[Idx] = Idx; + ++MemIndex; + } + VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx, + "Res" + Twine(Idx)); + } + VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask); + CI->replaceAllUsesWith(VResult); + CI->eraseFromParent(); + return; + } + + // If the mask is not v1i1, use scalar bit test operations. This generates + // better results on X86 at least. 
+ Value *SclrMask; + if (VectorWidth != 1) { + Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); + SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); + } + + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + // Fill the "else" block, created in the previous iteration + // + // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ] + // %mask_1 = extractelement <16 x i1> %mask, i32 Idx + // br i1 %mask_1, label %cond.load, label %else + // + + Value *Predicate; + if (VectorWidth != 1) { + Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); + Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), + Builder.getIntN(VectorWidth, 0)); + } else { + Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); + } + + // Create "cond" block + // + // %EltAddr = getelementptr i32* %1, i32 0 + // %Elt = load i32* %EltAddr + // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx + // + BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), + "cond.load"); + Builder.SetInsertPoint(InsertPt); + + LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1)); + Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx); + + // Move the pointer if there are more blocks to come. + Value *NewPtr; + if ((Idx + 1) != VectorWidth) + NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1); + + // Create "else" block, fill it in the next iteration + BasicBlock *NewIfBlock = + CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); + Builder.SetInsertPoint(InsertPt); + Instruction *OldBr = IfBlock->getTerminator(); + BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); + OldBr->eraseFromParent(); + BasicBlock *PrevIfBlock = IfBlock; + IfBlock = NewIfBlock; + + // Create the phi to join the new and previous value. + PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else"); + ResultPhi->addIncoming(NewVResult, CondBlock); + ResultPhi->addIncoming(VResult, PrevIfBlock); + VResult = ResultPhi; + + // Add a PHI for the pointer if this isn't the last iteration. + if ((Idx + 1) != VectorWidth) { + PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else"); + PtrPhi->addIncoming(NewPtr, CondBlock); + PtrPhi->addIncoming(Ptr, PrevIfBlock); + Ptr = PtrPhi; + } + } + + CI->replaceAllUsesWith(VResult); + CI->eraseFromParent(); + + ModifiedDT = true; +} + +static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) { + Value *Src = CI->getArgOperand(0); + Value *Ptr = CI->getArgOperand(1); + Value *Mask = CI->getArgOperand(2); + + auto *VecType = cast<FixedVectorType>(Src->getType()); + + IRBuilder<> Builder(CI->getContext()); + Instruction *InsertPt = CI; + BasicBlock *IfBlock = CI->getParent(); + + Builder.SetInsertPoint(InsertPt); + Builder.SetCurrentDebugLocation(CI->getDebugLoc()); + + Type *EltTy = VecType->getElementType(); + + unsigned VectorWidth = VecType->getNumElements(); + + // Shorten the way if the mask is a vector of constants. 
+ if (isConstantIntVector(Mask)) { + unsigned MemIndex = 0; + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) + continue; + Value *OneElt = + Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); + Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex); + Builder.CreateAlignedStore(OneElt, NewPtr, Align(1)); + ++MemIndex; + } + CI->eraseFromParent(); + return; + } + + // If the mask is not v1i1, use scalar bit test operations. This generates + // better results on X86 at least. + Value *SclrMask; + if (VectorWidth != 1) { + Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); + SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); + } + + for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { + // Fill the "else" block, created in the previous iteration + // + // %mask_1 = extractelement <16 x i1> %mask, i32 Idx + // br i1 %mask_1, label %cond.store, label %else + // + Value *Predicate; + if (VectorWidth != 1) { + Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); + Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), + Builder.getIntN(VectorWidth, 0)); + } else { + Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); + } + + // Create "cond" block + // + // %OneElt = extractelement <16 x i32> %Src, i32 Idx + // %EltAddr = getelementptr i32* %1, i32 0 + // %store i32 %OneElt, i32* %EltAddr + // + BasicBlock *CondBlock = + IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store"); + Builder.SetInsertPoint(InsertPt); + + Value *OneElt = Builder.CreateExtractElement(Src, Idx); + Builder.CreateAlignedStore(OneElt, Ptr, Align(1)); + + // Move the pointer if there are more blocks to come. + Value *NewPtr; + if ((Idx + 1) != VectorWidth) + NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1); + + // Create "else" block, fill it in the next iteration + BasicBlock *NewIfBlock = + CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); + Builder.SetInsertPoint(InsertPt); + Instruction *OldBr = IfBlock->getTerminator(); + BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); + OldBr->eraseFromParent(); + BasicBlock *PrevIfBlock = IfBlock; + IfBlock = NewIfBlock; + + // Add a PHI for the pointer if this isn't the last iteration. 
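+ // The phi below keeps the running store position: it takes the incremented
+ // NewPtr when the lane was stored (incoming from cond.store) and the
+ // unchanged Ptr when the lane was skipped (incoming from the previous block).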
+ if ((Idx + 1) != VectorWidth) { + PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else"); + PtrPhi->addIncoming(NewPtr, CondBlock); + PtrPhi->addIncoming(Ptr, PrevIfBlock); + Ptr = PtrPhi; + } + } + CI->eraseFromParent(); + + ModifiedDT = true; +} + +static bool runImpl(Function &F, const TargetTransformInfo &TTI) { + bool EverMadeChange = false; + bool MadeChange = true; + auto &DL = F.getParent()->getDataLayout(); + while (MadeChange) { + MadeChange = false; + for (Function::iterator I = F.begin(); I != F.end();) { + BasicBlock *BB = &*I++; + bool ModifiedDTOnIteration = false; + MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration, TTI, DL); + + // Restart BB iteration if the dominator tree of the Function was changed + if (ModifiedDTOnIteration) + break; + } + + EverMadeChange |= MadeChange; + } + return EverMadeChange; +} + +bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) { + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + return runImpl(F, TTI); +} + +PreservedAnalyses +ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) { + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + if (!runImpl(F, TTI)) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<TargetIRAnalysis>(); + return PA; +} + +static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, + const TargetTransformInfo &TTI, + const DataLayout &DL) { + bool MadeChange = false; + + BasicBlock::iterator CurInstIterator = BB.begin(); + while (CurInstIterator != BB.end()) { + if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++)) + MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL); + if (ModifiedDT) + return true; + } + + return MadeChange; +} + +static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, + const TargetTransformInfo &TTI, + const DataLayout &DL) { + IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI); + if (II) { + // The scalarization code below does not work for scalable vectors. 
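+ // For example, a masked load or store of a <vscale x 4 x i32> value is left
+ // untouched by this pass.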
+ if (isa<ScalableVectorType>(II->getType()) || + any_of(II->arg_operands(), + [](Value *V) { return isa<ScalableVectorType>(V->getType()); })) + return false; + + switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::masked_load: + // Scalarize unsupported vector masked load + if (TTI.isLegalMaskedLoad( + CI->getType(), + cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue())) + return false; + scalarizeMaskedLoad(CI, ModifiedDT); + return true; + case Intrinsic::masked_store: + if (TTI.isLegalMaskedStore( + CI->getArgOperand(0)->getType(), + cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue())) + return false; + scalarizeMaskedStore(CI, ModifiedDT); + return true; + case Intrinsic::masked_gather: { + unsigned AlignmentInt = + cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue(); + Type *LoadTy = CI->getType(); + Align Alignment = + DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), LoadTy); + if (TTI.isLegalMaskedGather(LoadTy, Alignment)) + return false; + scalarizeMaskedGather(CI, ModifiedDT); + return true; + } + case Intrinsic::masked_scatter: { + unsigned AlignmentInt = + cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue(); + Type *StoreTy = CI->getArgOperand(0)->getType(); + Align Alignment = + DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), StoreTy); + if (TTI.isLegalMaskedScatter(StoreTy, Alignment)) + return false; + scalarizeMaskedScatter(CI, ModifiedDT); + return true; + } + case Intrinsic::masked_expandload: + if (TTI.isLegalMaskedExpandLoad(CI->getType())) + return false; + scalarizeMaskedExpandLoad(CI, ModifiedDT); + return true; + case Intrinsic::masked_compressstore: + if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType())) + return false; + scalarizeMaskedCompressStore(CI, ModifiedDT); + return true; + } + } + + return false; +} |
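A minimal sketch of how the pass above can be scheduled with the new pass manager, assuming an LLVM 12 tree; the wrapper function name addMaskedMemScalarizer is illustrative only:

#include "llvm/IR/PassManager.h"
#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"

using namespace llvm;

// Lowers llvm.masked.* calls the target cannot handle before instruction
// selection runs.
void addMaskedMemScalarizer(FunctionPassManager &FPM) {
  FPM.addPass(ScalarizeMaskedMemIntrinPass());
}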