aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/libs/llvm12/lib/Target/X86/X86PreTileConfig.cpp
blob: b2f6d0604d1ab1c86ff3e3a50f87602a1c09bdb8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
//===-- X86PreTileConfig.cpp - Tile Register Configure---------------------===// 
// 
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
// See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 
// 
//===----------------------------------------------------------------------===// 
// 
/// \file Pass to pre-configure the shapes of AMX registers.
/// AMX registers need to be configured before use. The shape of an AMX
/// register is encoded in the 1st and 2nd machine operands of AMX pseudo
/// instructions. The pldtilecfg instruction configures the tile registers, so
/// it must dominate all AMX instructions. The pldtilecfg produces a virtual
/// cfg register, and that cfg register is used by all AMX instructions.
/// This pass finds the common dominator of all AMX instructions and inserts
/// the pldtilecfg instruction there. In addition, the cfg register that
/// pldtilecfg produces becomes the last operand of each AMX instruction. We
/// use this scheme to model the def-use relationship between the AMX config
/// instruction and the other AMX instructions. Below is an example.
/// 
///                        ----B1---- 
///                       /           \ 
///                      /             \ 
///                    B2               B3 
///    %1:tile = PTILELOADDV        %2:tile = PTILELOADDV 
/// 
///  is transformed to 
/// 
///                            B1 
///                 %25:tilecfg = PLDTILECFG 
///                       /           \ 
///                      /             \ 
///  %1:tile = PTILELOADDV %25    %2:tile = PTILELOADDV %25 
// 
//===----------------------------------------------------------------------===// 
 
#include "X86.h" 
#include "X86InstrBuilder.h" 
#include "X86RegisterInfo.h" 
#include "X86Subtarget.h" 
#include "llvm/CodeGen/MachineDominators.h" 
#include "llvm/CodeGen/MachineFunctionPass.h" 
#include "llvm/CodeGen/MachineInstr.h" 
#include "llvm/CodeGen/MachineRegisterInfo.h" 
#include "llvm/CodeGen/Passes.h" 
#include "llvm/CodeGen/TargetInstrInfo.h" 
#include "llvm/CodeGen/TargetRegisterInfo.h" 
#include "llvm/CodeGen/TileShapeInfo.h" 
#include "llvm/InitializePasses.h" 
 
using namespace llvm; 
 
#define DEBUG_TYPE "tile-pre-config" 
 
namespace {

/// Machine function pass that inserts a single PLDTILECFG pseudo, which
/// configures the AMX tile registers, and threads the virtual TILECFG register
/// it defines into every AMX pseudo instruction.
class X86PreTileConfig : public MachineFunctionPass {
  // Per-function context, (re)initialized at the start of
  // runOnMachineFunction. All pointers start out null so that no member is
  // ever read uninitialized.
  MachineFunction *MF = nullptr;
  const X86Subtarget *ST = nullptr;
  const TargetRegisterInfo *TRI = nullptr;
  const TargetInstrInfo *TII = nullptr;
  MachineDominatorTree *DomTree = nullptr;
  MachineRegisterInfo *MRI = nullptr;

  /// Return the instruction before which the tile configuration is inserted,
  /// or nullptr if the function has no (non-transient) tile register
  /// definition.
  MachineInstr *getTileConfigPoint();

public:
  X86PreTileConfig() : MachineFunctionPass(ID) {}

  /// Return the pass name.
  StringRef getPassName() const override {
    return "Tile Register Pre-configure";
  }

  /// Require the machine dominator tree; the pass preserves all analyses.
  void getAnalysisUsage(AnalysisUsage &AU) const override;

  /// Insert the tile configuration and attach its register to the AMX
  /// pseudo instructions. Returns true if the function was modified.
  bool runOnMachineFunction(MachineFunction &mf) override;

  static char ID;
};

} // end anonymous namespace
 
// Pass identifier; its address is handed to MachineFunctionPass by the
// X86PreTileConfig constructor.
char X86PreTileConfig::ID = 0;

// Register the pass under the argument name "tilepreconfig", declaring its
// dependency on MachineDominatorTree (also required in getAnalysisUsage).
// NOTE(review): the description string here ("Tile Register Configure")
// differs from getPassName() ("Tile Register Pre-configure"); consider
// aligning them.
INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                      "Tile Register Configure", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                    "Tile Register Configure", false, false)
 
/// Declare the analyses this pass depends on and keeps valid.
///
/// The pass needs the machine dominator tree to place the tile configuration.
/// It only inserts instructions into existing blocks and leaves the CFG
/// untouched, so it reports every analysis as preserved.
void X86PreTileConfig::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<MachineDominatorTree>();
  AU.setPreservesAll();
  // Let the base class add its own requirements last.
  MachineFunctionPass::getAnalysisUsage(AU);
}
 
/// Insert, before \p MI, code that zeroes the tile-config stack slot
/// \p FrameIdx and then loads it with a PLDTILECFG pseudo.
///
/// The slot is zeroed first so that any byte the later configuration does not
/// explicitly write is zero rather than stack garbage. LDTILECFG requires
/// reserved fields to be zero. Previously this was only done when AVX512 was
/// available; narrower vector stores are now used on other subtargets.
///
/// \returns the virtual TILECFG register defined by the PLDTILECFG.
static Register buildConfigMI(MachineBasicBlock::iterator MI, int FrameIdx,
                              const TargetInstrInfo *TII,
                              MachineRegisterInfo *MRI,
                              const X86Subtarget *ST) {
  auto *MBB = MI->getParent();

  // Zero the 64-byte stack slot with the widest available vector stores.
  // NOTE(review): 64 bytes assumes ST->getTileConfigSize() == 64, as the
  // original AVX512 path already did -- confirm if the size ever changes.
  // All stores are unaligned (MOVUPS family), so they do not depend on the
  // slot's alignment.
  if (ST->hasAVX512()) {
    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORDZrr), Zmm)
        .addReg(Zmm, RegState::Undef)
        .addReg(Zmm, RegState::Undef);
    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSZmr)),
                      FrameIdx)
        .addReg(Zmm);
  } else if (ST->hasAVX()) {
    // Two 32-byte stores cover the 64-byte slot.
    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::AVX_SET0), Ymm);
    for (int Offset : {0, 32})
      addFrameReference(
          BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSYmr)), FrameIdx,
          Offset)
          .addReg(Ymm);
  } else {
    // AMX is only usable in 64-bit mode, which always provides SSE2.
    assert(ST->hasSSE2() && "AMX should assume SSE2 enabled");
    // Four 16-byte stores cover the 64-byte slot.
    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::V_SET0), Xmm);
    for (int Offset : {0, 16, 32, 48})
      addFrameReference(
          BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)), FrameIdx,
          Offset)
          .addReg(Xmm);
  }

  // Build the pseudo ldtilecfg; it defines the virtual config register.
  Register VReg = MRI->createVirtualRegister(&X86::TILECFGRegClass);

  addFrameReference(
      BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::PLDTILECFG), VReg), FrameIdx);

  return VReg;
}
 
/// Build the tile shape carried by an AMX pseudo that defines a tile register.
///
/// For the supported pseudos, operand 1 is the row and operand 2 is the
/// column of the shape. Any other opcode is a programming error.
static ShapeT getShape(const MachineInstr &MI, MachineRegisterInfo *MRI) {
  switch (MI.getOpcode()) {
  case X86::PTILELOADDV:
  case X86::PTDPBSSDV:
  case X86::PTILEZEROV: {
    // ShapeT stores mutable operand pointers, so drop constness once here.
    auto &MutableMI = const_cast<MachineInstr &>(MI);
    return ShapeT(&MutableMI.getOperand(1), &MutableMI.getOperand(2), MRI);
  }
  default:
    llvm_unreachable("Unexpected machine instruction on tile");
  }
}
 
/// Find the instruction before which the tile configuration is inserted.
///
/// The target block is the nearest common dominator of every non-undef,
/// non-transient definition of a virtual TILE register. Within that block,
/// the point is right after the last instruction that defines a row/column
/// shape register, not counting move-immediates. It falls back to the first
/// non-PHI instruction when there is no such definition in the block, or when
/// that definition is a PHI, because nothing may be placed between PHIs.
///
/// \returns nullptr if no such tile register definition exists.
/// Reports a fatal error if a non-move-immediate shape definition does not
/// dominate the chosen block, or if no insertion point exists in it.
MachineInstr *X86PreTileConfig::getTileConfigPoint() {
  // Nearest common dominator of the tile definitions seen so far.
  MachineBasicBlock *MBB = nullptr;
  // Instructions that define the shape (row/column) registers of those tiles.
  DenseSet<const MachineInstr *> MIs;
  for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
    Register VirtReg = Register::index2VirtReg(i);
    if (MRI->reg_nodbg_empty(VirtReg))
      continue;
    const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
    if (RC.getID() != X86::TILERegClassID)
      continue;

    // Find the common dominator for all MI that define tile register.
    for (const MachineOperand &MO : MRI->def_operands(VirtReg)) {
      if (MO.isUndef())
        continue;
      const auto *MI = MO.getParent();
      // PHI or IMPLICIT_DEF instruction.
      // There must be an input tile before a PHI instruction.
      if (MI->isTransient())
        continue;
      auto *DefMBB = const_cast<MachineBasicBlock *>(MI->getParent());
      MBB = MBB ? DomTree->findNearestCommonDominator(MBB, DefMBB) : DefMBB;

      // Collect the instructions that define shape.
      ShapeT Shape = getShape(*MI, MRI);
      std::array<MachineOperand *, 2> ShapeMOs = {Shape.getRow(),
                                                  Shape.getCol()};
      for (auto *ShapeMO : ShapeMOs) {
        Register ShapeReg = ShapeMO->getReg();
        for (const MachineOperand &ShapeDefMO : MRI->def_operands(ShapeReg))
          MIs.insert(ShapeDefMO.getParent());
      }
    }
  }
  if (!MBB)
    return nullptr;
  // This pass is before the pass of eliminating PHI node, so it
  // is in SSA form.
  assert(MRI->isSSA() && "Not SSA form in pre-tile config");
  // Shape def should dominate tile config MBB.
  //    def s           s1    s2
  //     / \             \   /
  //    /   \             \ /
  //  conf               s3=phi(s1,s2)
  //                       |
  //                       c
  //
  // Move-immediate shape defs are tolerated even when they do not dominate.
  // NOTE(review): presumably because their value is recovered as an
  // immediate by ShapeT -- confirm.
  for (const auto *MI : MIs) {
    const MachineBasicBlock *ShapeMBB = MI->getParent();
    if (DomTree->dominates(ShapeMBB, MBB))
      continue;
    if (MI->isMoveImmediate())
      continue;
    report_fatal_error(MF->getName() + ": Failed to config tile register, "
                                       "please define the shape earlier");
  }

  // ldtilecfg should be inserted after the MI that defines the shape.
  MachineBasicBlock::reverse_instr_iterator I, E;
  for (I = MBB->instr_rbegin(), E = MBB->instr_rend(); I != E; ++I) {
    auto *MI = &*I;
    if (MIs.count(MI) && (!MI->isMoveImmediate()))
      break;
  }
  MachineBasicBlock::iterator MII;
  if (I == E || I->isPHI()) {
    // Either no shape is defined in MBB, or the last definition is a PHI.
    // Inserting right after a PHI could land between PHIs, so go past all of
    // them instead.
    MII = MBB->getFirstNonPHI();
  } else {
    MII = MachineBasicBlock::iterator(&*I);
    ++MII;
  }
  // The caller needs a real instruction to insert before; end() has none,
  // and dereferencing it would be undefined behavior.
  if (MII == MBB->end())
    report_fatal_error(MF->getName() + ": Failed to config tile register, "
                                       "please define the shape earlier");
  return &*MII;
}
 
/// Make every AMX pseudo in \p MF use the tile-config register \p CFG.
///
/// For each PTILELOADDV, PTILESTOREDV, PTDPBSSDV and PTILEZEROV, the last
/// operand is removed and replaced by a use of \p CFG.
static void addTileCFGUse(MachineFunction &MF, Register CFG) {
  auto IsAMXPseudo = [](unsigned Opc) -> bool {
    switch (Opc) {
    case X86::PTILELOADDV:
    case X86::PTILESTOREDV:
    case X86::PTDPBSSDV:
    case X86::PTILEZEROV:
      return true;
    default:
      return false;
    }
  };

  for (MachineBasicBlock &Block : MF) {
    for (MachineInstr &Inst : Block) {
      if (!IsAMXPseudo(Inst.getOpcode()))
        continue;
      // Swap the trailing operand for a (non-def) use of the config register.
      Inst.RemoveOperand(Inst.getNumOperands() - 1);
      Inst.addOperand(MF, MachineOperand::CreateReg(CFG, /*isDef=*/false));
    }
  }
}
 
/// Insert the tile configuration for \p mf and wire it into AMX pseudos.
///
/// Returns false, leaving the function untouched, when there is no tile
/// register definition to configure. Otherwise it creates a stack slot for
/// the configuration, emits the config code at the chosen point, and returns
/// true.
bool X86PreTileConfig::runOnMachineFunction(MachineFunction &mf) {
  // Cache the per-function context used by the helpers.
  MF = &mf;
  ST = &mf.getSubtarget<X86Subtarget>();
  TII = ST->getInstrInfo();
  TRI = ST->getRegisterInfo();
  MRI = &mf.getRegInfo();
  DomTree = &getAnalysis<MachineDominatorTree>();

  MachineInstr *InsertPt = getTileConfigPoint();
  if (InsertPt == nullptr)
    return false;

  // Stack slot holding the tile configuration, sized and aligned as the
  // subtarget requires.
  MachineFrameInfo &FrameInfo = mf.getFrameInfo();
  const int CfgSlot =
      FrameInfo.CreateStackObject(ST->getTileConfigSize(),
                                  ST->getTileConfigAlignment(),
                                  /*isSpillSlot=*/false);
  const Register CfgReg = buildConfigMI(InsertPt, CfgSlot, TII, MRI, ST);
  addTileCFGUse(mf, CfgReg);
  return true;
}
 
/// Create a new instance of the tile register pre-configuration pass.
FunctionPass *llvm::createX86PreTileConfigPass() {
  return new X86PreTileConfig;
}