aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/libs/llvm12/include/llvm/CodeGen/TileShapeInfo.h
blob: 083a3bd4e903f38ae9c632efbe4987e659d72b69 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#pragma once 
 
#ifdef __GNUC__ 
#pragma GCC diagnostic push 
#pragma GCC diagnostic ignored "-Wunused-parameter" 
#endif 
 
//===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===// 
// 
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
// See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 
// 
//===----------------------------------------------------------------------===// 
// 
/// \file Shape utility for AMX.
/// AMX hardware requires the shape of a tile data register to be configured
/// before the register is used. The 2D shape consists of a row and a column
/// dimension. In the AMX intrinsics interface the shape is passed as the 1st
/// and 2nd parameters, which are lowered to the 1st and 2nd machine operands
/// of the AMX pseudo instructions. The ShapeT class facilitates tile
/// configuration and register allocation; its row and column are those
/// machine operands of the AMX pseudo instructions.
// 
//===----------------------------------------------------------------------===// 
 
#ifndef LLVM_CODEGEN_TILESHAPEINFO_H 
#define LLVM_CODEGEN_TILESHAPEINFO_H 
 
#include "llvm/ADT/DenseMapInfo.h" 
#include "llvm/CodeGen/MachineInstr.h" 
#include "llvm/CodeGen/MachineOperand.h" 
#include "llvm/CodeGen/MachineRegisterInfo.h" 
#include "llvm/CodeGen/Register.h" 
#include <utility> 
 
namespace llvm { 
 
/// Shape (row x column) of an AMX tile register.
///
/// The row and column are held as pointers to the machine operands of the AMX
/// pseudo instruction that carries them. When a MachineRegisterInfo is given,
/// the constructor also tries to recover constant values for both dimensions
/// (see deduceImm()); a dimension whose value is unknown is recorded as
/// InvalidImmShape.
class ShapeT {
public:
  /// Build a shape from the row/column operands. If \p MRI is non-null, also
  /// try to deduce immediate values for both dimensions; otherwise both
  /// immediates stay InvalidImmShape.
  ShapeT(MachineOperand *Row, MachineOperand *Col,
         const MachineRegisterInfo *MRI = nullptr)
      : Row(Row), Col(Col) {
    if (MRI)
      deduceImm(MRI);
  }
  /// Build an empty shape: no operands, no known immediates (isValid() is
  /// false).
  ShapeT()
      : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
        ColImm(InvalidImmShape) {}

  /// Two shapes compare equal only when both have row and column operands,
  /// and either:
  ///  - their row and column operands name the same registers, or
  ///  - this shape knows immediates for both dimensions and they match the
  ///    immediates of \p Shape.
  bool operator==(const ShapeT &Shape) const {
    MachineOperand *R = Shape.Row;
    MachineOperand *C = Shape.Col;
    if (!R || !C)
      return false;
    if (!Row || !Col)
      return false;
    if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
      return true;
    if ((RowImm != InvalidImmShape) && (ColImm != InvalidImmShape))
      return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
    return false;
  }

  bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); }

  MachineOperand *getRow() const { return Row; }

  MachineOperand *getCol() const { return Col; }

  /// Immediate row value, or InvalidImmShape (-1) if it is unknown.
  int64_t getRowImm() const { return RowImm; }

  /// Immediate column value, or InvalidImmShape (-1) if it is unknown.
  int64_t getColImm() const { return ColImm; }

  /// True when both the row and the column operand are present.
  bool isValid() const { return (Row != nullptr) && (Col != nullptr); }

  /// Try to recover constant row/column values from the definitions of the
  /// row/column registers. For each register, the first defining instruction
  /// that is a move-immediate with an immediate source operand supplies the
  /// value. If there is no such definition, or the operand is missing, the
  /// dimension is set to InvalidImmShape.
  void deduceImm(const MachineRegisterInfo *MRI) {
    // All defs must carry the same value, otherwise the MIs are invalid, so
    // the first immediate found is taken as the answer.
    // TODO copy propagation.
    auto GetImm = [&](Register Reg) {
      int64_t Imm = InvalidImmShape;
      for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
        const auto *MI = DefMO.getParent();
        // Only read operand 1 when it really is an immediate; getImm() on a
        // non-immediate operand is invalid.
        if (MI->isMoveImmediate() && MI->getOperand(1).isImm()) {
          Imm = MI->getOperand(1).getImm();
          break;
        }
      }
      return Imm;
    };
    // Guard against shapes built without operands.
    RowImm = Row ? GetImm(Row->getReg()) : InvalidImmShape;
    ColImm = Col ? GetImm(Col->getReg()) : InvalidImmShape;
  }

private:
  static constexpr int64_t InvalidImmShape = -1;
  MachineOperand *Row;
  MachineOperand *Col;
  // Default-initialized so that shapes built without an MRI never expose
  // indeterminate values through operator== or the getters.
  int64_t RowImm = InvalidImmShape;
  int64_t ColImm = InvalidImmShape;
};
 
} // namespace llvm 
 
#endif 
 
#ifdef __GNUC__ 
#pragma GCC diagnostic pop 
#endif