blob: 5f935511e0c6eafd6fa53a16164a13aa7edd03c2 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
|
#pragma once
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#endif
//===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Shape utility for AMX.
/// AMX hardware requires the shape of a tile data register to be configured
/// before use. The 2D shape consists of a row and a column. In the AMX
/// intrinsics interface the shape is passed as the 1st and 2nd parameters,
/// which are lowered to the 1st and 2nd machine operands of the AMX pseudo
/// instructions. The ShapeT class facilitates tile configuration and register
/// allocation. The row and column are machine operands of AMX pseudo
/// instructions.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_CODEGEN_TILESHAPEINFO_H
#define LLVM_CODEGEN_TILESHAPEINFO_H
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Register.h"
namespace llvm {
/// Describes the 2D shape (row and column) of a tile.
///
/// The shape is held as two non-owning MachineOperand pointers. When a
/// MachineRegisterInfo is supplied, the class additionally records the
/// immediate values that define the row and column registers, which lets two
/// shapes compare equal even when they are expressed through different
/// registers.
class ShapeT {
public:
  /// Builds a shape from the given row and column operands.
  ///
  /// \param Row machine operand for the shape's row; not owned.
  /// \param Col machine operand for the shape's column; not owned.
  /// \param MRI if non-null, used to deduce the immediate row/column values.
  ///        If null, both immediates remain InvalidImmShape.
  ShapeT(MachineOperand *Row, MachineOperand *Col,
         const MachineRegisterInfo *MRI = nullptr)
      : Row(Row), Col(Col) {
    if (MRI)
      deduceImm(MRI);
  }

  /// Builds an invalid shape: no operands and no known immediates.
  ShapeT() = default;

  /// Returns true when both shapes have row and column operands and either
  /// use the same registers or, if this shape's immediates are both known,
  /// have equal row and column immediates.
  ///
  /// A shape with a null row or column operand never compares equal to
  /// anything, including itself.
  bool operator==(const ShapeT &Shape) const {
    const MachineOperand *R = Shape.Row;
    const MachineOperand *C = Shape.Col;
    if (!R || !C)
      return false;
    if (!Row || !Col)
      return false;
    if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
      return true;
    // Different registers may still carry the same immediate value. If the
    // other shape's immediates are unknown they hold InvalidImmShape and thus
    // cannot match a known immediate here.
    if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape))
      return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
    return false;
  }

  /// Negation of operator==.
  bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); }

  /// Returns the row operand, or nullptr if none was set.
  MachineOperand *getRow() const { return Row; }

  /// Returns the column operand, or nullptr if none was set.
  MachineOperand *getCol() const { return Col; }

  /// Returns the deduced row immediate, or InvalidImmShape (-1) if unknown.
  int64_t getRowImm() const { return RowImm; }

  /// Returns the deduced column immediate, or InvalidImmShape (-1) if unknown.
  int64_t getColImm() const { return ColImm; }

  /// Returns true when both the row and the column operand are set.
  bool isValid() const { return (Row != nullptr) && (Col != nullptr); }

  /// Tries to find the immediate values defining the row and column registers
  /// and stores them in RowImm / ColImm.
  ///
  /// For each register, the definitions reported by \p MRI are scanned and the
  /// immediate of the first move-immediate instruction found is used; if there
  /// is none, the value is InvalidImmShape. If \p MRI, the row operand or the
  /// column operand is null, both immediates are set to InvalidImmShape.
  void deduceImm(const MachineRegisterInfo *MRI) {
    RowImm = InvalidImmShape;
    ColImm = InvalidImmShape;
    if (!MRI || !Row || !Col)
      return;

    // All defs must be the same value, otherwise the MIs are invalid, so the
    // first move-immediate found is taken as the register's value.
    // TODO: copy propagation.
    auto GetImm = [MRI](Register Reg) {
      int64_t Imm = InvalidImmShape;
      for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
        const auto *MI = DefMO.getParent();
        if (MI->isMoveImmediate()) {
          // Operand 1 of the move-immediate is read as the immediate source.
          Imm = MI->getOperand(1).getImm();
          break;
        }
      }
      return Imm;
    };
    RowImm = GetImm(Row->getReg());
    ColImm = GetImm(Col->getReg());
  }

private:
  /// Sentinel meaning "immediate value not known".
  static constexpr int64_t InvalidImmShape = -1;
  MachineOperand *Row = nullptr;
  MachineOperand *Col = nullptr;
  // Default-initialized so that a shape built without a MachineRegisterInfo
  // never exposes indeterminate immediates.
  int64_t RowImm = InvalidImmShape;
  int64_t ColImm = InvalidImmShape;
};
} // namespace llvm
#endif
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
|