summaryrefslogtreecommitdiffstats
path: root/contrib/libs/llvm16/lib/Target/X86/X86PreAMXConfig.cpp
diff options
context:
space:
mode:
authorvvvv <[email protected]>2024-02-06 20:01:22 +0300
committervvvv <[email protected]>2024-02-06 20:22:16 +0300
commit0203b7a9a40828bb2bd4c32029b79ff0ea3d1f8f (patch)
treee630d0d5bd0bd29fc8c2d2842ed2cfde781b993a /contrib/libs/llvm16/lib/Target/X86/X86PreAMXConfig.cpp
parentba27db76d99d12a4f1c06960b5449423218614c4 (diff)
llvm16 targets
Diffstat (limited to 'contrib/libs/llvm16/lib/Target/X86/X86PreAMXConfig.cpp')
-rw-r--r--contrib/libs/llvm16/lib/Target/X86/X86PreAMXConfig.cpp414
1 files changed, 414 insertions, 0 deletions
diff --git a/contrib/libs/llvm16/lib/Target/X86/X86PreAMXConfig.cpp b/contrib/libs/llvm16/lib/Target/X86/X86PreAMXConfig.cpp
new file mode 100644
index 00000000000..2429b85cf86
--- /dev/null
+++ b/contrib/libs/llvm16/lib/Target/X86/X86PreAMXConfig.cpp
@@ -0,0 +1,414 @@
+//===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// Insert tilecfg for each area of key AMX intrinsic.
+/// All the key AMX intrinsic's tile operand must come from tileload. And the
+/// def tile of key AMX intrinsic must be tilestored.
+/// take tdpbssd for example:
+/// --------------------------------------------------------------------------
+/// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...) key
+/// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...) |
+/// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...) amx
+/// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3) |
+/// call void @llvm.x86.tilestored64.internal(... td) area
+/// --------------------------------------------------------------------------
+/// This pass will insert tilecfg before every key-amx-area, some like:
+/// --------------------------------------------------------------------------
+/// %cfgmem = alloca <16 x i32>, align 4 * allocate mem
+/// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem * zero init
+/// ...
+/// ... pre-config shape of %t1 *
+/// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1 *
+/// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2 * pre-config
+/// ... *
+/// ... pre-config shape of %t2 * shapes
+/// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 *
+/// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 *
+/// ...
+/// call void @llvm.x86.ldtilecfg(i8* %cfgmem) * tile config
+//
+//===----------------------------------------------------------------------===//
+//
+#include "X86.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsX86.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Target/TargetMachine.h"
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "pre-amx-config"
+
+static bool isAMXIntrinsic(IntrinsicInst *II) {
+ for (Value *Operand : II->operands())
+ if (Operand->getType()->isX86_AMXTy())
+ return true;
+ return II->getType()->isX86_AMXTy();
+}
+
+static bool isTileLoad(IntrinsicInst *II) {
+ return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal ||
+ II->getIntrinsicID() == Intrinsic::x86_tileloaddt164_internal;
+}
+
+static bool isTileStore(IntrinsicInst *II) {
+ return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal;
+}
+
+#ifndef NDEBUG
+static bool onlyTileDef(IntrinsicInst *II) {
+ for (Value *Operand : II->operands())
+ if (Operand->getType()->isX86_AMXTy())
+ return false;
+ return II->getType()->isX86_AMXTy();
+}
+
+static bool brokenVolatile(Instruction *I) {
+ // Todo: it is weak to identify a normal call here.
+ if ((isa<CallInst>(I) && !isa<IntrinsicInst>(I)) || I->isTerminator())
+ return true;
+ return false;
+}
+#endif
+
+namespace {
+class X86PreAMXConfig {
+ using PosAndShapesMap = MapVector<Instruction *, SmallVector<Value *, 8>>;
+
+ Function &F;
+
+public:
+ X86PreAMXConfig(Function &Func) : F(Func) {}
+ bool preTileConfig();
+ void addTileConfig(Instruction *ModelStart, SmallVector<Value *, 8> &Shapes);
+ bool findConfigShapes(PosAndShapesMap &PosAndShapes);
+ bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector<Value *, 8> &Shapes);
+ void preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
+ SmallVector<Value *, 8> &Shapes);
+ BasicBlock::iterator
+ getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
+ SmallVector<Value *, 8> &Shapes);
+ bool checkVolatileModel(SmallSet<Value *, 4> &Loads, IntrinsicInst *Store,
+ IntrinsicInst *KeyAMX);
+};
+
+// Orderly write the shapes in tilecfg's mem. This maybe not right.
+// Because the first shape may not corresponding to the first tmm register,
+// so we need to handle at at X86FastTileConfig::materializeTileCfg()
+// after register allocation.
+// For example:
+// --------------------------------------------------------------------------
+// zeroinitialize tilecfg's mem (of ldtilecfg)
+// --------------------------------------------------------------------------
+// ... pre-config shape of %t1 *
+// %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48 *
+// %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16 *
+// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1 *
+// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2 * pre-config
+// ... *
+// ... pre-config shape of %t2 *
+// %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49 *
+// %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18 *
+// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 * shapes
+// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 *
+// ... *
+// ... pre-config shape of %t3 * of
+// %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50 *
+// %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20 *
+// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1 *
+// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2 *
+// ... * tiles
+// ... pre-config shape of %td *
+// %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51 *
+// %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22 *
+// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1 *
+// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2 *
+// --------------------------------------------------------------------------
+// call void @llvm.x86.ldtilecfg(i8* %mem) * tile config
+// --------------------------------------------------------------------------
+// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
+// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
+// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
+// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
+// call void @llvm.x86.tilestored64.internal(... td) area
+// --------------------------------------------------------------------------
+void X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, IRBuilderBase &Builder,
+ SmallVector<Value *, 8> &Shapes) {
+ LLVMContext &Ctx = Builder.getContext();
+ Type *I8Ty = Type::getInt8Ty(Ctx);
+ Type *I16Ty = Type::getInt16Ty(Ctx);
+
+ // TODO: Currently we defaultly set Palette = 1, it may be assigned to
+ // other value in the future.
+ Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0);
+ Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
+ Value *PalettePos = Builder.CreateGEP(I8Ty, I8Ptr, PaletteOffset);
+ Builder.CreateStore(PaletteValue, PalettePos);
+
+ for (int I = 0, E = Shapes.size() / 2; I < E; I++) {
+ Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I);
+ Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2);
+ const std::string ShapeName = "amx.tmm." + itostr(I);
+ Value *RowPos = Builder.CreateGEP(I8Ty, I8Ptr, RowOffset,
+ ShapeName + ".shape.row");
+ Value *ColPos = Builder.CreateGEP(I8Ty, I8Ptr, ColOffset);
+ ColPos = Builder.CreateBitCast(ColPos, PointerType::get(I16Ty, 0),
+ ShapeName + ".shape.col");
+ Value *Row = Shapes[I * 2];
+ Value *Col = Shapes[I * 2 + 1];
+ Row = Builder.CreateTrunc(Row, I8Ty);
+ Builder.CreateStore(Row, RowPos);
+ Builder.CreateStore(Col, ColPos);
+ }
+}
+
+void X86PreAMXConfig::addTileConfig(Instruction *ModelStart,
+ SmallVector<Value *, 8> &Shapes) {
+ Module *M = F.getParent();
+ IRBuilder<> Builder(ModelStart);
+ const DataLayout &DL = M->getDataLayout();
+ unsigned AddrSpace = DL.getAllocaAddrSpace();
+ LLVMContext &Ctx = Builder.getContext();
+ Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false);
+ Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx));
+
+ AllocaInst *Addr =
+ new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front());
+ Addr->setAlignment(Alignment);
+ Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy());
+
+ Builder.CreateAlignedStore(Constant::getNullValue(V512Ty), Addr, Alignment);
+
+ preWriteTileCfg(I8Ptr, Builder, Shapes);
+
+ Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, std::nullopt,
+ {I8Ptr});
+}
+
+// Todo: We may need to handle "more than one store" case in the future.
+bool X86PreAMXConfig::checkVolatileModel(SmallSet<Value *, 4> &Loads,
+ IntrinsicInst *Store,
+ IntrinsicInst *KeyAMX) {
+ Value *ST = Store->getOperand(4);
+
+ // Only has tileload and tilestore.
+ if (!KeyAMX)
+ return (Loads.size() == 1) && Loads.contains(ST);
+
+ // All Loads should be operands of KeyAMX.
+ // All tile operands of KeyAMX should come from Loads.
+ for (Value *Op : KeyAMX->operands()) {
+ if (Op->getType()->isX86_AMXTy())
+ if (!Loads.erase(Op))
+ return false;
+ }
+
+ // The def of KeyAMX should be stored into mem.
+ // Todo: is it key amx can be no def?
+ return Loads.empty() && (ST == cast<Value>(KeyAMX));
+}
+
+bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX,
+ SmallVector<Value *, 8> &Shapes) {
+ for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) {
+ Value *Op = KeyAMX->getOperand(I);
+ if (!Op->getType()->isX86_AMXTy())
+ continue;
+ IntrinsicInst *TileDef = dyn_cast<IntrinsicInst>(Op);
+ assert((TileDef && isTileLoad(TileDef)) &&
+ "All KeyAMX's tile definiation should comes from TileLoad!");
+ Shapes.push_back(TileDef->getOperand(0));
+ Shapes.push_back(TileDef->getOperand(1));
+ }
+ if (!isTileStore(KeyAMX)) {
+ Shapes.push_back(KeyAMX->getOperand(0));
+ Shapes.push_back(KeyAMX->getOperand(1));
+ }
+ return Shapes.size() != 0;
+}
+
+// Collect the shapes and skip the area of current key amx intrinsic.
+//
+// For example:
+// ...
+// --------------------------------------------------------------------------
+// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) record (m,k)
+// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) record (m,k)
+// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) record (m,k)
+// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3)
+// call void @llvm.x86.tilestored64.internal(m, n,... td) <--PosEnd record (m,k)
+// --------------------------------------------------------------------------
+BasicBlock::iterator
+X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter,
+ SmallVector<Value *, 8> &Shapes) {
+ IntrinsicInst *KeyAMX = nullptr;
+ BasicBlock *BB = Iter->getParent();
+ BasicBlock::iterator PosEnd = BB->end();
+ SmallSet<Value *, 4> Loads;
+
+ // See TileStore as "Config Position End" and check volatile model.
+ for (auto I = Iter, E = BB->end(); I != E; ++I) {
+ assert(!brokenVolatile(&*I) && "Not reach tile store!");
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
+ if (!II || !isAMXIntrinsic(II))
+ continue;
+
+ if (isTileLoad(II)) {
+ Loads.insert(II);
+ } else if (isTileStore(II)) {
+ if (!checkVolatileModel(Loads, II, KeyAMX))
+ report_fatal_error("Not Volatile AMX Model!");
+ PosEnd = I;
+ break;
+ } else {
+ assert(!KeyAMX && "Too many key amx intrinsic!");
+ KeyAMX = II;
+ }
+ }
+ assert(PosEnd != BB->end() && "Not find TileStore!");
+
+ // See KeyAMX as TileStore if only TileLoad and TileStore.
+ if (!KeyAMX)
+ KeyAMX = dyn_cast<IntrinsicInst>(&*PosEnd);
+
+ // Get Shapes in order.
+ assert(Shapes.empty() && "Shapes should be clean.");
+ getKeyAMXShapes(KeyAMX, Shapes);
+
+ return PosEnd;
+}
+
+// Record a key amx area's shapes with its position.
+// Use the first tileload as its position.
+// For example:
+// ...
+// --------------------------------------------------------------------------
+// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) <-- pos
+// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) /
+// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) shapes:
+// %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3) (m,k)(k,n)
+// call void @llvm.x86.tilestored64.internal(m, n,... td) (m,n)(m,n)
+// --------------------------------------------------------------------------
+bool X86PreAMXConfig::findConfigShapes(PosAndShapesMap &PosAndShapes) {
+ bool Find = false;
+ for (BasicBlock &BB : F) {
+ for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(&*I);
+ if (!II)
+ continue;
+ if (!isAMXIntrinsic(II))
+ continue;
+ assert(onlyTileDef(II) && "Not volatile model for AMX at O0!");
+
+ I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]);
+ Find = true;
+ }
+ }
+ return Find;
+}
+
+// Insert ldtilecfg and preconfig the shapes for each area of key AMX intrinsic.
+// e.g. (key amx = tdpbssd)
+// --------------------------------------------------------------------------
+// %cfgmem = alloca <16 x i32>, align 4 * allocate mem
+// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem * zero init
+// ...
+// ... pre-config shape of %t1 *
+// store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1 *
+// store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2 * pre-config
+// ... *
+// ... pre-config shape of %t2 *
+// store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 * shapes
+// store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 *
+// ... *
+// ... pre-config shape of %t3 * of
+// store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1 *
+// store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2 *
+// ... * tiles
+// ... pre-config shape of %td *
+// store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1 *
+// store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2 *
+//
+// call void @llvm.x86.ldtilecfg(i8* %cfgmem) * pre-config
+// --------------------------------------------------------------------------
+// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key
+// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...)
+// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx
+// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3)
+// call void @llvm.x86.tilestored64.internal(... td) area
+// --------------------------------------------------------------------------
+bool X86PreAMXConfig::preTileConfig() {
+ PosAndShapesMap PosAndShapes;
+ bool NeedCfg = findConfigShapes(PosAndShapes);
+ if (!NeedCfg)
+ return false;
+ for (auto &IPAndShapes : PosAndShapes)
+ addTileConfig(IPAndShapes.first, IPAndShapes.second);
+
+ return true;
+}
+} // anonymous namespace
+
+namespace {
+
+class X86PreAMXConfigPass : public FunctionPass {
+public:
+ static char ID;
+
+ X86PreAMXConfigPass() : FunctionPass(ID) {
+ initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override {
+ TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
+ bool C = false;
+
+ // Prepare for fast register allocation at O0.
+ if (TM->getOptLevel() == CodeGenOpt::None) {
+
+ // We pre-config each key AMX intrinsic at O0.
+ // In theory, one tile config can cover several AMX intrinsics, but
+ // it is very diffcult to classify the tile shapes at O0. So here we
+ // let thing be easy, pre-config every key AMX intrinsic.
+ X86PreAMXConfig PCFG(F);
+ C = PCFG.preTileConfig();
+ }
+
+ return C;
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.setPreservesCFG();
+ AU.addRequired<TargetPassConfig>();
+ }
+};
+
+} // anonymous namespace
+
+static const char PassName[] = "Pre AMX Tile Config";
+char X86PreAMXConfigPass::ID = 0;
+INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
+INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false)
+
+FunctionPass *llvm::createX86PreAMXConfigPass() {
+ return new X86PreAMXConfigPass();
+}