| author | vvvv <[email protected]> | 2024-02-06 20:01:22 +0300 |
|---|---|---|
| committer | vvvv <[email protected]> | 2024-02-06 20:22:16 +0300 |
| commit | 0203b7a9a40828bb2bd4c32029b79ff0ea3d1f8f (patch) | |
| tree | e630d0d5bd0bd29fc8c2d2842ed2cfde781b993a /contrib/libs/llvm16/lib/Target/X86/X86PreAMXConfig.cpp | |
| parent | ba27db76d99d12a4f1c06960b5449423218614c4 (diff) | |
llvm16 targets
Diffstat (limited to 'contrib/libs/llvm16/lib/Target/X86/X86PreAMXConfig.cpp')
| -rw-r--r-- | contrib/libs/llvm16/lib/Target/X86/X86PreAMXConfig.cpp | 414 |
1 file changed, 414 insertions, 0 deletions
diff --git a/contrib/libs/llvm16/lib/Target/X86/X86PreAMXConfig.cpp b/contrib/libs/llvm16/lib/Target/X86/X86PreAMXConfig.cpp
new file mode 100644
index 00000000000..2429b85cf86
--- /dev/null
+++ b/contrib/libs/llvm16/lib/Target/X86/X86PreAMXConfig.cpp
@@ -0,0 +1,414 @@
+//===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// Insert tilecfg for each area of key AMX intrinsic.
+/// All the key AMX intrinsic's tile operand must come from tileload. And the
+/// def tile of key AMX intrinsic must be tilestored.
+/// take tdpbssd for example:
+/// --------------------------------------------------------------------------
+/// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)                key
+/// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)                 |
+/// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)                amx
+/// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3)         |
+/// call void @llvm.x86.tilestored64.internal(... td)                     area
+/// --------------------------------------------------------------------------
+/// This pass will insert tilecfg before every key-amx-area, some like:
+/// --------------------------------------------------------------------------
+/// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
+/// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
+/// ...
+/// ... pre-config shape of %t1                                 *
+/// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
+/// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
+/// ...                                                         *
+/// ... pre-config shape of %t2                                 * shapes
+/// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     *
+/// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
+/// ...
+/// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * tile config
+//
+//===----------------------------------------------------------------------===//
+//
+#include "X86.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsX86.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Target/TargetMachine.h"
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "pre-amx-config"
+
+static bool isAMXIntrinsic(IntrinsicInst *II) {
+  for (Value *Operand : II->operands())
+    if (Operand->getType()->isX86_AMXTy())
+      return true;
+  return II->getType()->isX86_AMXTy();
+}
+
+static bool isTileLoad(IntrinsicInst *II) {
+  return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal ||
+         II->getIntrinsicID() == Intrinsic::x86_tileloaddt164_internal;
+}
+
+static bool isTileStore(IntrinsicInst *II) {
+  return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
+}
+
+#ifndef NDEBUG
+static bool onlyTileDef(IntrinsicInst *II) {
+  for (Value *Operand : II->operands())
+    if (Operand->getType()->isX86_AMXTy())
+      return false;
+  return II->getType()->isX86_AMXTy();
+}
+
+static bool brokenVolatile(Instruction *I) {
+  // Todo: it is weak to identify a normal call here.
+  if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
+    return true;
+  return false;
+}
+#endif
+
+namespace {
+class X86PreAMXConfig {
+  using PosAndShapesMap = MapVector<Instruction *, SmallVector<Value *, 8>>;
+
+  Function &F;
+
+public:
+  X86PreAMXConfig(Function &Func) : F(Func) {}
+  bool preTileConfig();
+  void addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
+  bool findConfigShapes(PosAndShapesMap &PosAndShapes);
+  bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
+  void preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
+                       SmallVector<Value *, 8> &Shapes);
+  BasicBlock::iterator
+  getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
+                           SmallVector<Value *, 8> &Shapes);
+  bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
+                          IntrinsicInst *KeyAMX);
+};
+
+// Orderly write the shapes in tilecfg's mem. This maybe not right.
+// Because the first shape may not corresponding to the first tmm register,
+// so we need to handle at at X86FastTileConfig::materializeTileCfg()
+// after register allocation.
+// For example:
+// --------------------------------------------------------------------------
+// zeroinitialize tilecfg's mem (of ldtilecfg)
+// --------------------------------------------------------------------------
+// ... pre-config shape of %t1                                  *
+// %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48    *
+// %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16  *
+// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1      *
+// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2    * pre-config
+// ...                                                          *
+// ... pre-config shape of %t2                                  *
+// %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49    *
+// %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18  *
+// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1      * shapes
+// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2    *
+// ...                                                          *
+// ... pre-config shape of %t3                                  * of
+// %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50    *
+// %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20  *
+// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1      *
+// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2    *
+// ...                                                          * tiles
+// ... pre-config shape of %td                                  *
+// %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51    *
+// %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22  *
+// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1      *
+// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2    *
+// --------------------------------------------------------------------------
+// call void @llvm.x86.ldtilecfg(i8* %mem)                      * tile config
+// --------------------------------------------------------------------------
+// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
+// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
+// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
+// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
+// call void @llvm.x86.tilestored64.internal(... td)                     area
+// --------------------------------------------------------------------------
+void X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
+                                      SmallVector<Value *, 8> &Shapes) {
+  LLVMContext &Ctx = Builder.getContext();
+  Type *I8Ty = Type::getInt8Ty(Ctx);
+  Type *I16Ty = Type::getInt16Ty(Ctx);
+
+  // TODO: Currently we defaultly set Palette = 1, it may be assigned to
+  // other value in the future.
+  Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
+  Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
+  Value *PalettePos = Builder.CreateGEP(I8Ty, I8Ptr, PaletteOffset);
+  Builder.CreateStore(PaletteValue, PalettePos);
+
+  for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
+    Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
+    Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
+    const std::string ShapeName = "amx.tmm." + itostr(I);
+    Value *RowPos = Builder.CreateGEP(I8Ty, I8Ptr, RowOffset,
+                                      ShapeName + ".shape.row");
+    Value *ColPos = Builder.CreateGEP(I8Ty, I8Ptr, ColOffset);
+    ColPos = Builder.CreateBitCast(ColPos, PointerType::get(I16Ty, 0),
+                                   ShapeName + ".shape.col");
+    Value *Row = Shapes[I * 2];
+    Value *Col = Shapes[I * 2 + 1];
+    Row = Builder.CreateTrunc(Row, I8Ty);
+    Builder.CreateStore(Row, RowPos);
+    Builder.CreateStore(Col, ColPos);
+  }
+}
+
+void X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
+                                    SmallVector<Value *, 8> &Shapes) {
+  Module *M = F.getParent();
+  IRBuilder<> Builder(ModelStart);
+  const DataLayout &DL = M->getDataLayout();
+  unsigned AddrSpace = DL.getAllocaAddrSpace();
+  LLVMContext &Ctx = Builder.getContext();
+  Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
+  Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));
+
+  AllocaInst *Addr =
+      new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
+  Addr->setAlignment(Alignment);
+  Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());
+
+  Builder.CreateAlignedStore(Constant::getNullValue(V512Ty), Addr, Alignment);
+
+  preWriteTileCfg(I8Ptr, Builder, Shapes);
+
+  Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, std::nullopt,
+                          {I8Ptr});
+}
+
+// Todo: We may need to handle "more than one store" case in the future.
+bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
+                                         IntrinsicInst *Store,
+                                         IntrinsicInst *KeyAMX) {
+  Value *ST = Store->getOperand(4);
+
+  // Only has tileload and tilestore.
+  if (!KeyAMX)
+    return (Loads.size() == 1) && Loads.contains(ST);
+
+  // All Loads should be operands of KeyAMX.
+  // All tile operands of KeyAMX should come from Loads.
+  for (Value *Op : KeyAMX->operands()) {
+    if (Op->getType()->isX86_AMXTy())
+      if (!Loads.erase(Op))
+        return false;
+  }
+
+  // The def of KeyAMX should be stored into mem.
+  // Todo: is it key amx can be no def?
+  return Loads.empty() && (ST == cast<Value>(KeyAMX));
+}
+
+bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
+                                      SmallVector<Value *, 8> &Shapes) {
+  for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
+    Value *Op = KeyAMX->getOperand(I);
+    if (!Op->getType()->isX86_AMXTy())
+      continue;
+    IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
+    assert((TileDef && isTileLoad(TileDef)) &&
+           "All KeyAMX's tile definiation should comes from TileLoad!");
+    Shapes.push_back(TileDef->getOperand(0));
+    Shapes.push_back(TileDef->getOperand(1));
+  }
+  if (!isTileStore(KeyAMX)) {
+    Shapes.push_back(KeyAMX->getOperand(0));
+    Shapes.push_back(KeyAMX->getOperand(1));
+  }
+  return Shapes.size() != 0;
+}
+
+// Collect the shapes and skip the area of current key amx intrinsic.
+//
+// For example:
+// ...
+// --------------------------------------------------------------------------
+// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)  record (m,k)
+// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)  record (m,k)
+// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)  record (m,k)
+// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
+// call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,k)
+// --------------------------------------------------------------------------
+BasicBlock::iterator
+X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
+                                          SmallVector<Value *, 8> &Shapes) {
+  IntrinsicInst *KeyAMX = nullptr;
+  BasicBlock *BB = Iter->getParent();
+  BasicBlock::iterator PosEnd = BB->end();
+  SmallSet<Value *, 4> Loads;
+
+  // See TileStore as "Config Position End" and check volatile model.
+  for (auto I = Iter, E = BB->end(); I != E; ++I) {
+    assert(!brokenVolatile(&*I) && "Not reach tile store!");
+    IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
+    if (!II || !isAMXIntrinsic(II))
+      continue;
+
+    if (isTileLoad(II)) {
+      Loads.insert(II);
+    } else if (isTileStore(II)) {
+      if (!checkVolatileModel(Loads, II, KeyAMX))
+        report_fatal_error("Not Volatile AMX Model!");
+      PosEnd = I;
+      break;
+    } else {
+      assert(!KeyAMX && "Too many key amx intrinsic!");
+      KeyAMX = II;
+    }
+  }
+  assert(PosEnd != BB->end() && "Not find TileStore!");
+
+  // See KeyAMX as TileStore if only TileLoad and TileStore.
+  if (!KeyAMX)
+    KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);
+
+  // Get Shapes in order.
+  assert(Shapes.empty() && "Shapes should be clean.");
+  getKeyAMXShapes(KeyAMX, Shapes);
+
+  return PosEnd;
+}
+
+// Record a key amx area's shapes with its position.
+// Use the first tileload as its position.
+// For example:
+// ...
+// --------------------------------------------------------------------------
+// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)   <-- pos
+// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)        /
+// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)     shapes:
+// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)    (m,k)(k,n)
+// call void @llvm.x86.tilestored64.internal(m, n,... td)          (m,n)(m,n)
+// --------------------------------------------------------------------------
+bool X86PreAMXConfig::findConfigShapes(PosAndShapesMap &PosAndShapes) {
+  bool Find = false;
+  for (BasicBlock &BB : F) {
+    for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
+      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
+      if (!II)
+        continue;
+      if (!isAMXIntrinsic(II))
+        continue;
+      assert(onlyTileDef(II) && "Not volatile model for AMX at O0!");
+
+      I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
+      Find = true;
+    }
+  }
+  return Find;
+}
+
+// Insert ldtilecfg and preconfig the shapes for each area of key AMX intrinsic.
+// e.g. (key amx = tdpbssd)
+// --------------------------------------------------------------------------
+// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
+// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
+// ...
+// ... pre-config shape of %t1                                 *
+// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1     *
+// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2   * pre-config
+// ...                                                         *
+// ... pre-config shape of %t2                                 *
+// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1     * shapes
+// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2   *
+// ...                                                         *
+// ... pre-config shape of %t3                                 * of
+// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1     *
+// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2   *
+// ...                                                         * tiles
+// ... pre-config shape of %td                                 *
+// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1     *
+// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2   *
+//
+// call void @llvm.x86.ldtilecfg(i8* %cfgmem)                  * pre-config
+// --------------------------------------------------------------------------
+// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...)          key
+// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
+// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...)          amx
+// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
+// call void @llvm.x86.tilestored64.internal(... td)                     area
+// --------------------------------------------------------------------------
+bool X86PreAMXConfig::preTileConfig() {
+  PosAndShapesMap PosAndShapes;
+  bool NeedCfg = findConfigShapes(PosAndShapes);
+  if (!NeedCfg)
+    return false;
+  for (auto &IPAndShapes : PosAndShapes)
+    addTileConfig(IPAndShapes.first, IPAndShapes.second);
+
+  return true;
+}
+} // anonymous namespace
+
+namespace {
+
+class X86PreAMXConfigPass : public FunctionPass {
+public:
+  static char ID;
+
+  X86PreAMXConfigPass() : FunctionPass(ID) {
+    initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
+    bool C = false;
+
+    // Prepare for fast register allocation at O0.
+    if (TM->getOptLevel() == CodeGenOpt::None) {
+
+      // We pre-config each key AMX intrinsic at O0.
+      // In theory, one tile config can cover several AMX intrinsics, but
+      // it is very diffcult to classify the tile shapes at O0. So here we
+      // let thing be easy, pre-config every key AMX intrinsic.
+      X86PreAMXConfig PCFG(F);
+      C = PCFG.preTileConfig();
+    }
+
+    return C;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    AU.addRequired<TargetPassConfig>();
+  }
+};
+
+} // anonymous namespace
+
+static const char PassName[] = "Pre AMX Tile Config";
+char X86PreAMXConfigPass::ID = 0;
+INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
+INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
+
+FunctionPass *llvm::createX86PreAMXConfigPass() {
+  return new X86PreAMXConfigPass();
+}
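
For orientation, the byte offsets that `preWriteTileCfg` writes into the `ldtilecfg` memory are: the palette byte at offset 0 (always set to 1 by this pass), the i-th tile's row count at byte `48 + i`, and its 16-bit column size at byte `16 + 2 * i`. The sketch below is a minimal host-side illustration of that same 64-byte layout; the type and function names (`TileConfigImage`, `setShape`) are invented for the example and are not part of the pass or of any LLVM API.

```cpp
#include <cstdint>
#include <cstring>
#include <cstdio>

// Hypothetical mirror of the 64-byte buffer that ldtilecfg reads.
// Offsets follow preWriteTileCfg above: palette at byte 0, row counts at
// bytes 48..55, 16-bit column sizes (in bytes) at bytes 16, 18, 20, ...
struct TileConfigImage {
  uint8_t bytes[64] = {};

  void setPalette(uint8_t palette) { bytes[0] = palette; }

  void setShape(unsigned tmm, uint8_t rows, uint16_t colBytes) {
    bytes[48 + tmm] = rows;                      // row count of tile `tmm`
    std::memcpy(&bytes[16 + 2 * tmm], &colBytes, // column size in bytes
                sizeof(colBytes));
  }
};

int main() {
  // Example shapes for a tdpbssd area: t1 = (m x k), t2 = (k x n),
  // t3 and td = (m x n); the column values are byte counts.
  const uint8_t m = 16, k = 16;
  const uint16_t n = 64;
  TileConfigImage cfg;
  cfg.setPalette(1);     // the pass currently always uses palette 1
  cfg.setShape(0, m, k); // %t1
  cfg.setShape(1, k, n); // %t2
  cfg.setShape(2, m, n); // %t3
  cfg.setShape(3, m, n); // %td
  std::printf("palette=%u row0=%u\n", (unsigned)cfg.bytes[0],
              (unsigned)cfg.bytes[48]);
  return 0;
}
```

The actual mapping of each shape to a tmm register is only finalized after register allocation (see the comment above about X86FastTileConfig::materializeTileCfg); this sketch only shows the memory layout the pass pre-writes at O0.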
