path: root/contrib/libs/llvm12/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
author     orivej <orivej@yandex-team.ru>  2022-02-10 16:44:49 +0300
committer  Daniil Cherednik <dcherednik@yandex-team.ru>  2022-02-10 16:44:49 +0300
commit     718c552901d703c502ccbefdfc3c9028d608b947 (patch)
tree       46534a98bbefcd7b1f3faa5b52c138ab27db75b7 /contrib/libs/llvm12/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
parent     e9656aae26e0358d5378e5b63dcac5c8dbe0e4d0 (diff)
download   ydb-718c552901d703c502ccbefdfc3c9028d608b947.tar.gz
Restoring authorship annotation for <orivej@yandex-team.ru>. Commit 1 of 2.
Diffstat (limited to 'contrib/libs/llvm12/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp')
-rw-r--r--  contrib/libs/llvm12/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp  3726
1 file changed, 1863 insertions, 1863 deletions
diff --git a/contrib/libs/llvm12/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp b/contrib/libs/llvm12/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
index 8e251ca940..78e926254e 100644
--- a/contrib/libs/llvm12/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ b/contrib/libs/llvm12/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -1,1214 +1,1214 @@
-//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// Lower matrix intrinsics to vector operations.
-//
-// TODO:
-// * Improve fusion:
-// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
-// transposed.
-// * Improve cost-modeling, e.g. choose a different number of rows/columns
-// for tiles, consider the cost of copies on alias.
-//
-//===----------------------------------------------------------------------===//
-
-#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
-#include "llvm/ADT/GraphTraits.h"
-#include "llvm/ADT/PostOrderIterator.h"
-#include "llvm/ADT/SmallVector.h"
-#include "llvm/Analysis/AliasAnalysis.h"
-#include "llvm/Analysis/DomTreeUpdater.h"
-#include "llvm/Analysis/OptimizationRemarkEmitter.h"
-#include "llvm/Analysis/TargetTransformInfo.h"
-#include "llvm/Analysis/ValueTracking.h"
-#include "llvm/Analysis/VectorUtils.h"
-#include "llvm/IR/CFG.h"
-#include "llvm/IR/DataLayout.h"
-#include "llvm/IR/DebugInfoMetadata.h"
-#include "llvm/IR/Function.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Instructions.h"
-#include "llvm/IR/IntrinsicInst.h"
-#include "llvm/IR/PatternMatch.h"
-#include "llvm/InitializePasses.h"
-#include "llvm/Pass.h"
-#include "llvm/Support/Alignment.h"
-#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
-#include "llvm/Transforms/Scalar.h"
-#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+//===- LowerMatrixIntrinsics.cpp - Lower matrix intrinsics -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Lower matrix intrinsics to vector operations.
+//
+// TODO:
+// * Improve fusion:
+// * Support more cases, e.g. multiply-add, multiply-sub, operands/results
+// transposed.
+// * Improve cost-modeling, e.g. choose a different number of rows/columns
+// for tiles, consider the cost of copies on alias.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LowerMatrixIntrinsics.h"
+#include "llvm/ADT/GraphTraits.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Alignment.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/MatrixUtils.h"
-
-using namespace llvm;
-using namespace PatternMatch;
-
-#define DEBUG_TYPE "lower-matrix-intrinsics"
-
-static cl::opt<bool> EnableShapePropagation(
- "matrix-propagate-shape", cl::init(true), cl::Hidden,
- cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
- "instructions."));
-
-static cl::opt<bool>
- FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
- cl::desc("Enable/disable fusing matrix instructions."));
-// TODO: Allow and use non-square tiles.
-static cl::opt<unsigned> TileSize(
- "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
- cl::desc(
- "Tile size for matrix instruction fusion using square-shaped tiles."));
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "lower-matrix-intrinsics"
+
+static cl::opt<bool> EnableShapePropagation(
+ "matrix-propagate-shape", cl::init(true), cl::Hidden,
+ cl::desc("Enable/disable shape propagation from matrix intrinsics to other "
+ "instructions."));
+
+static cl::opt<bool>
+ FuseMatrix("fuse-matrix", cl::init(true), cl::Hidden,
+ cl::desc("Enable/disable fusing matrix instructions."));
+// TODO: Allow and use non-square tiles.
+static cl::opt<unsigned> TileSize(
+ "fuse-matrix-tile-size", cl::init(4), cl::Hidden,
+ cl::desc(
+ "Tile size for matrix instruction fusion using square-shaped tiles."));
static cl::opt<bool> TileUseLoops("fuse-matrix-use-loops", cl::init(false),
cl::Hidden,
cl::desc("Generate loop nest for tiling."));
-static cl::opt<bool> ForceFusion(
- "force-fuse-matrix", cl::init(false), cl::Hidden,
- cl::desc("Force matrix instruction fusion even if not profitable."));
-static cl::opt<bool> AllowContractEnabled(
- "matrix-allow-contract", cl::init(false), cl::Hidden,
- cl::desc("Allow the use of FMAs if available and profitable. This may "
- "result in different results, due to less rounding error."));
-
-enum class MatrixLayoutTy { ColumnMajor, RowMajor };
-
-static cl::opt<MatrixLayoutTy> MatrixLayout(
- "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
- cl::desc("Sets the default matrix layout"),
- cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
- "Use column-major layout"),
- clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
- "Use row-major layout")));
-
-/// Helper function to either return Scope, if it is a subprogram, or the
-/// attached subprogram for a local scope.
-static DISubprogram *getSubprogram(DIScope *Scope) {
- if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
- return Subprogram;
- return cast<DILocalScope>(Scope)->getSubprogram();
-}
-
-namespace {
-
-// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
-// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
-// assuming \p Stride elements between the starts of two consecutive vectors.
-// \p Stride must be >= \p NumElements.
-// For column-major matrixes, the function computes the address of a column
-// vector and \p NumElements must be set to the number of elements in a column
-// (= number of rows of the matrix). For row-major matrixes, the function
-// computes the address of a row vector and \p NumElements must be set to the
-// number of elements in a row (= number of columns of the matrix).
-//
-// Consider a 4x4 matrix in column-major layout like below
-//
-// 0 1 2 3
-// 0 v_0_0 v_0_1 v_0_2 v_0_3
-// 1 v_1_0 v_1_1 v_1_2 v_1_3
-// 2 v_2_0 v_2_1 v_2_2 v_2_3
-// 3 v_3_0 v_3_1 v_3_2 v_3_3
-
-// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
-// we need a pointer to the first element of the submatrix as base pointer.
-// Then we can use computeVectorAddr to compute the addresses for the columns
-// of the sub-matrix.
-//
-// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
-// -> just returns Base
-// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
-// -> returns Base + (1 * 4)
-// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
-// -> returns Base + (2 * 4)
-//
-// The graphic below illustrates the number of elements in a column (marked
-// with |) and the number of skipped elements (marked with }).
-//
-// v_0_0 v_0_1 {v_0_2 {v_0_3
-// Base Col 1 Col 2
-// | | |
-// v_1_0 |v_1_1 |v_1_2 |v_1_3
-// v_2_0 |v_2_1 |v_2_2 |v_2_3
-// v_3_0 {v_3_1 {v_3_2 v_3_3
-//
-Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
- unsigned NumElements, Type *EltType,
- IRBuilder<> &Builder) {
-
- assert((!isa<ConstantInt>(Stride) ||
- cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
- "Stride must be >= the number of elements in the result vector.");
- unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
-
- // Compute the start of the vector with index VecIdx as VecIdx * Stride.
- Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
-
- // Get pointer to the start of the selected vector. Skip GEP creation,
- // if we select vector 0.
- if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
- VecStart = BasePtr;
- else
- VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
-
- // Cast elementwise vector start pointer to a pointer to a vector
- // (EltType x NumElements)*.
- auto *VecType = FixedVectorType::get(EltType, NumElements);
- Type *VecPtrType = PointerType::get(VecType, AS);
- return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
-}
-
-/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
-///
-/// Currently, the lowering for each matrix intrinsic is done as follows:
-/// 1. Propagate the shape information from intrinsics to connected
-/// instructions.
-/// 2. Lower instructions with shape information (assuming column-major layout).
-/// The lowering works similarly using row-major layout.
-/// 2.1. Get column vectors for each argument. If we already lowered the
-/// definition of an argument, use the produced column vectors directly.
-/// If not, split the operand vector containing an embedded matrix into
-/// a set of column vectors.
-/// 2.2. Lower the instruction in terms of column major operations, which
-/// yields a set of column vectors containing result matrix. Note that we
-/// lower all instructions that have shape information. Besides the
-/// intrinsics, this includes stores for example.
-/// 2.3. Update uses of the lowered instruction. If we have shape information
-/// for a user, there is nothing to do, as we will look up the result
-/// column matrix when lowering the user. For other uses, we embed the
-/// result matrix in a flat vector and update the use.
-/// 2.4. Cache the result column matrix for the instruction we lowered.
-/// 3. After we lowered all instructions in a function, remove the now
-/// obsolete instructions.
-///
-class LowerMatrixIntrinsics {
- Function &Func;
- const DataLayout &DL;
- const TargetTransformInfo &TTI;
+static cl::opt<bool> ForceFusion(
+ "force-fuse-matrix", cl::init(false), cl::Hidden,
+ cl::desc("Force matrix instruction fusion even if not profitable."));
+static cl::opt<bool> AllowContractEnabled(
+ "matrix-allow-contract", cl::init(false), cl::Hidden,
+ cl::desc("Allow the use of FMAs if available and profitable. This may "
+ "result in different results, due to less rounding error."));
+
+enum class MatrixLayoutTy { ColumnMajor, RowMajor };
+
+static cl::opt<MatrixLayoutTy> MatrixLayout(
+ "matrix-default-layout", cl::init(MatrixLayoutTy::ColumnMajor),
+ cl::desc("Sets the default matrix layout"),
+ cl::values(clEnumValN(MatrixLayoutTy::ColumnMajor, "column-major",
+ "Use column-major layout"),
+ clEnumValN(MatrixLayoutTy::RowMajor, "row-major",
+ "Use row-major layout")));
+
+/// Helper function to either return Scope, if it is a subprogram, or the
+/// attached subprogram for a local scope.
+static DISubprogram *getSubprogram(DIScope *Scope) {
+ if (auto *Subprogram = dyn_cast<DISubprogram>(Scope))
+ return Subprogram;
+ return cast<DILocalScope>(Scope)->getSubprogram();
+}
+
+namespace {
+
+// Given an element pointer \p BasePtr to the start of a (sub) matrix, compute
+// the start address of vector \p VecIdx with type (\p EltType x \p NumElements)
+// assuming \p Stride elements between the starts of two consecutive vectors.
+// \p Stride must be >= \p NumElements.
+// For column-major matrixes, the function computes the address of a column
+// vector and \p NumElements must be set to the number of elements in a column
+// (= number of rows of the matrix). For row-major matrixes, the function
+// computes the address of a row vector and \p NumElements must be set to the
+// number of elements in a row (= number of columns of the matrix).
+//
+// Consider a 4x4 matrix in column-major layout like below
+//
+// 0 1 2 3
+// 0 v_0_0 v_0_1 v_0_2 v_0_3
+// 1 v_1_0 v_1_1 v_1_2 v_1_3
+// 2 v_2_0 v_2_1 v_2_2 v_2_3
+// 3 v_3_0 v_3_1 v_3_2 v_3_3
+
+// To compute the column addresses for a 2x3 sub-matrix at row 1 and column 1,
+// we need a pointer to the first element of the submatrix as base pointer.
+// Then we can use computeVectorAddr to compute the addresses for the columns
+// of the sub-matrix.
+//
+// Column 0: computeVectorAddr(Base, 0 (column), 4 (stride), 2 (num rows), ..)
+// -> just returns Base
+// Column 1: computeVectorAddr(Base, 1 (column), 4 (stride), 2 (num rows), ..)
+// -> returns Base + (1 * 4)
+// Column 2: computeVectorAddr(Base, 2 (column), 4 (stride), 2 (num rows), ..)
+// -> returns Base + (2 * 4)
+//
+// The graphic below illustrates the number of elements in a column (marked
+// with |) and the number of skipped elements (marked with }).
+//
+// v_0_0 v_0_1 {v_0_2 {v_0_3
+// Base Col 1 Col 2
+// | | |
+// v_1_0 |v_1_1 |v_1_2 |v_1_3
+// v_2_0 |v_2_1 |v_2_2 |v_2_3
+// v_3_0 {v_3_1 {v_3_2 v_3_3
+//
+Value *computeVectorAddr(Value *BasePtr, Value *VecIdx, Value *Stride,
+ unsigned NumElements, Type *EltType,
+ IRBuilder<> &Builder) {
+
+ assert((!isa<ConstantInt>(Stride) ||
+ cast<ConstantInt>(Stride)->getZExtValue() >= NumElements) &&
+ "Stride must be >= the number of elements in the result vector.");
+ unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
+
+ // Compute the start of the vector with index VecIdx as VecIdx * Stride.
+ Value *VecStart = Builder.CreateMul(VecIdx, Stride, "vec.start");
+
+ // Get pointer to the start of the selected vector. Skip GEP creation,
+ // if we select vector 0.
+ if (isa<ConstantInt>(VecStart) && cast<ConstantInt>(VecStart)->isZero())
+ VecStart = BasePtr;
+ else
+ VecStart = Builder.CreateGEP(EltType, BasePtr, VecStart, "vec.gep");
+
+ // Cast elementwise vector start pointer to a pointer to a vector
+ // (EltType x NumElements)*.
+ auto *VecType = FixedVectorType::get(EltType, NumElements);
+ Type *VecPtrType = PointerType::get(VecType, AS);
+ return Builder.CreatePointerCast(VecStart, VecPtrType, "vec.cast");
+}
+
+/// LowerMatrixIntrinsics contains the methods used to lower matrix intrinsics.
+///
+/// Currently, the lowering for each matrix intrinsic is done as follows:
+/// 1. Propagate the shape information from intrinsics to connected
+/// instructions.
+/// 2. Lower instructions with shape information (assuming column-major layout).
+/// The lowering works similarly using row-major layout.
+/// 2.1. Get column vectors for each argument. If we already lowered the
+/// definition of an argument, use the produced column vectors directly.
+/// If not, split the operand vector containing an embedded matrix into
+/// a set of column vectors.
+/// 2.2. Lower the instruction in terms of column major operations, which
+/// yields a set of column vectors containing result matrix. Note that we
+/// lower all instructions that have shape information. Besides the
+/// intrinsics, this includes stores for example.
+/// 2.3. Update uses of the lowered instruction. If we have shape information
+/// for a user, there is nothing to do, as we will look up the result
+/// column matrix when lowering the user. For other uses, we embed the
+/// result matrix in a flat vector and update the use.
+/// 2.4. Cache the result column matrix for the instruction we lowered.
+/// 3. After we lowered all instructions in a function, remove the now
+/// obsolete instructions.
+///
+class LowerMatrixIntrinsics {
+ Function &Func;
+ const DataLayout &DL;
+ const TargetTransformInfo &TTI;
AliasAnalysis *AA;
DominatorTree *DT;
LoopInfo *LI;
OptimizationRemarkEmitter *ORE;
-
- /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
- struct OpInfoTy {
- /// Number of stores emitted to generate this matrix.
- unsigned NumStores = 0;
- /// Number of loads emitted to generate this matrix.
- unsigned NumLoads = 0;
- /// Number of compute operations emitted to generate this matrix.
- unsigned NumComputeOps = 0;
-
- OpInfoTy &operator+=(const OpInfoTy &RHS) {
- NumStores += RHS.NumStores;
- NumLoads += RHS.NumLoads;
- NumComputeOps += RHS.NumComputeOps;
- return *this;
- }
- };
-
- /// Wrapper class representing a matrix as a set of vectors, either in row or
- /// column major layout. All vectors must have the same vector type.
- class MatrixTy {
- SmallVector<Value *, 16> Vectors;
-
- OpInfoTy OpInfo;
-
- bool IsColumnMajor = true;
-
- public:
- MatrixTy()
- : Vectors(),
- IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
- MatrixTy(ArrayRef<Value *> Vectors)
- : Vectors(Vectors.begin(), Vectors.end()),
- IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
- MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
- : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
-
- unsigned D = isColumnMajor() ? NumColumns : NumRows;
- for (unsigned J = 0; J < D; ++J)
- addVector(UndefValue::get(FixedVectorType::get(
- EltTy, isColumnMajor() ? NumRows : NumColumns)));
- }
-
- Value *getVector(unsigned i) const { return Vectors[i]; }
- Value *getColumn(unsigned i) const {
- assert(isColumnMajor() && "only supported for column-major matrixes");
- return Vectors[i];
- }
- Value *getRow(unsigned i) const {
- assert(!isColumnMajor() && "only supported for row-major matrixes");
- return Vectors[i];
- }
-
- void setVector(unsigned i, Value *V) { Vectors[i] = V; }
-
+
+ /// Contains estimates of the number of operations (loads, stores, compute) required to lower a matrix operation.
+ struct OpInfoTy {
+ /// Number of stores emitted to generate this matrix.
+ unsigned NumStores = 0;
+ /// Number of loads emitted to generate this matrix.
+ unsigned NumLoads = 0;
+ /// Number of compute operations emitted to generate this matrix.
+ unsigned NumComputeOps = 0;
+
+ OpInfoTy &operator+=(const OpInfoTy &RHS) {
+ NumStores += RHS.NumStores;
+ NumLoads += RHS.NumLoads;
+ NumComputeOps += RHS.NumComputeOps;
+ return *this;
+ }
+ };
+
+ /// Wrapper class representing a matrix as a set of vectors, either in row or
+ /// column major layout. All vectors must have the same vector type.
+ class MatrixTy {
+ SmallVector<Value *, 16> Vectors;
+
+ OpInfoTy OpInfo;
+
+ bool IsColumnMajor = true;
+
+ public:
+ MatrixTy()
+ : Vectors(),
+ IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
+ MatrixTy(ArrayRef<Value *> Vectors)
+ : Vectors(Vectors.begin(), Vectors.end()),
+ IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
+ MatrixTy(unsigned NumRows, unsigned NumColumns, Type *EltTy)
+ : IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {
+
+ unsigned D = isColumnMajor() ? NumColumns : NumRows;
+ for (unsigned J = 0; J < D; ++J)
+ addVector(UndefValue::get(FixedVectorType::get(
+ EltTy, isColumnMajor() ? NumRows : NumColumns)));
+ }
+
+ Value *getVector(unsigned i) const { return Vectors[i]; }
+ Value *getColumn(unsigned i) const {
+ assert(isColumnMajor() && "only supported for column-major matrixes");
+ return Vectors[i];
+ }
+ Value *getRow(unsigned i) const {
+ assert(!isColumnMajor() && "only supported for row-major matrixes");
+ return Vectors[i];
+ }
+
+ void setVector(unsigned i, Value *V) { Vectors[i] = V; }
+
Type *getElementType() const { return getVectorTy()->getElementType(); }
-
- unsigned getNumVectors() const {
- if (isColumnMajor())
- return getNumColumns();
- return getNumRows();
- }
-
- unsigned getNumColumns() const {
- if (isColumnMajor())
- return Vectors.size();
- else {
- assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
- return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
- }
- }
- unsigned getNumRows() const {
- if (isColumnMajor()) {
- assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
- return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
- } else
- return Vectors.size();
- }
-
- void addVector(Value *V) { Vectors.push_back(V); }
- VectorType *getColumnTy() {
- assert(isColumnMajor() && "only supported for column-major matrixes");
- return getVectorTy();
- }
-
+
+ unsigned getNumVectors() const {
+ if (isColumnMajor())
+ return getNumColumns();
+ return getNumRows();
+ }
+
+ unsigned getNumColumns() const {
+ if (isColumnMajor())
+ return Vectors.size();
+ else {
+ assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
+ return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
+ }
+ }
+ unsigned getNumRows() const {
+ if (isColumnMajor()) {
+ assert(Vectors.size() > 0 && "Cannot call getNumRows without columns");
+ return cast<FixedVectorType>(Vectors[0]->getType())->getNumElements();
+ } else
+ return Vectors.size();
+ }
+
+ void addVector(Value *V) { Vectors.push_back(V); }
+ VectorType *getColumnTy() {
+ assert(isColumnMajor() && "only supported for column-major matrixes");
+ return getVectorTy();
+ }
+
VectorType *getVectorTy() const {
- return cast<VectorType>(Vectors[0]->getType());
- }
-
- iterator_range<SmallVector<Value *, 8>::iterator> columns() {
- assert(isColumnMajor() &&
- "columns() only supported for column-major matrixes");
- return make_range(Vectors.begin(), Vectors.end());
- }
-
- iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
- return make_range(Vectors.begin(), Vectors.end());
- }
-
- /// Embed the vectors of the matrix into a flat vector by concatenating
- /// them.
- Value *embedInVector(IRBuilder<> &Builder) const {
- return Vectors.size() == 1 ? Vectors[0]
- : concatenateVectors(Builder, Vectors);
- }
-
- MatrixTy &addNumLoads(unsigned N) {
- OpInfo.NumLoads += N;
- return *this;
- }
-
- void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
-
- MatrixTy &addNumStores(unsigned N) {
- OpInfo.NumStores += N;
- return *this;
- }
-
- MatrixTy &addNumComputeOps(unsigned N) {
- OpInfo.NumComputeOps += N;
- return *this;
- }
-
- unsigned getNumStores() const { return OpInfo.NumStores; }
- unsigned getNumLoads() const { return OpInfo.NumLoads; }
- unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
-
- const OpInfoTy &getOpInfo() const { return OpInfo; }
-
- bool isColumnMajor() const { return IsColumnMajor; }
-
- unsigned getStride() const {
- if (isColumnMajor())
- return getNumRows();
- return getNumColumns();
- }
-
- /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
- /// matrix is column-major, the result vector is extracted from a column
- /// vector, otherwise from a row vector.
- Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
- IRBuilder<> &Builder) const {
- Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
- return Builder.CreateShuffleVector(
+ return cast<VectorType>(Vectors[0]->getType());
+ }
+
+ iterator_range<SmallVector<Value *, 8>::iterator> columns() {
+ assert(isColumnMajor() &&
+ "columns() only supported for column-major matrixes");
+ return make_range(Vectors.begin(), Vectors.end());
+ }
+
+ iterator_range<SmallVector<Value *, 8>::iterator> vectors() {
+ return make_range(Vectors.begin(), Vectors.end());
+ }
+
+ /// Embed the vectors of the matrix into a flat vector by concatenating
+ /// them.
+ Value *embedInVector(IRBuilder<> &Builder) const {
+ return Vectors.size() == 1 ? Vectors[0]
+ : concatenateVectors(Builder, Vectors);
+ }
+
+ MatrixTy &addNumLoads(unsigned N) {
+ OpInfo.NumLoads += N;
+ return *this;
+ }
+
+ void setNumLoads(unsigned N) { OpInfo.NumLoads = N; }
+
+ MatrixTy &addNumStores(unsigned N) {
+ OpInfo.NumStores += N;
+ return *this;
+ }
+
+ MatrixTy &addNumComputeOps(unsigned N) {
+ OpInfo.NumComputeOps += N;
+ return *this;
+ }
+
+ unsigned getNumStores() const { return OpInfo.NumStores; }
+ unsigned getNumLoads() const { return OpInfo.NumLoads; }
+ unsigned getNumComputeOps() const { return OpInfo.NumComputeOps; }
+
+ const OpInfoTy &getOpInfo() const { return OpInfo; }
+
+ bool isColumnMajor() const { return IsColumnMajor; }
+
+ unsigned getStride() const {
+ if (isColumnMajor())
+ return getNumRows();
+ return getNumColumns();
+ }
+
+ /// Extract a vector of \p NumElts starting at index (\p I, \p J). If the
+ /// matrix is column-major, the result vector is extracted from a column
+ /// vector, otherwise from a row vector.
+ Value *extractVector(unsigned I, unsigned J, unsigned NumElts,
+ IRBuilder<> &Builder) const {
+ Value *Vec = isColumnMajor() ? getColumn(J) : getRow(I);
+ return Builder.CreateShuffleVector(
Vec, createSequentialMask(isColumnMajor() ? I : J, NumElts, 0),
- "block");
- }
- };
-
- struct ShapeInfo {
- unsigned NumRows;
- unsigned NumColumns;
-
- bool IsColumnMajor;
-
- ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
- : NumRows(NumRows), NumColumns(NumColumns),
- IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
-
- ShapeInfo(Value *NumRows, Value *NumColumns)
- : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
- cast<ConstantInt>(NumColumns)->getZExtValue()) {}
-
- bool operator==(const ShapeInfo &other) {
- return NumRows == other.NumRows && NumColumns == other.NumColumns;
- }
- bool operator!=(const ShapeInfo &other) { return !(*this == other); }
-
- /// Returns true if shape-information is defined, meaning both dimensions
- /// are != 0.
- operator bool() const {
- assert(NumRows == 0 || NumColumns != 0);
- return NumRows != 0;
- }
-
- unsigned getStride() const {
- if (IsColumnMajor)
- return NumRows;
- return NumColumns;
- }
-
- unsigned getNumVectors() const {
- if (IsColumnMajor)
- return NumColumns;
- return NumRows;
- }
- };
-
- /// Maps instructions to their shape information. The shape information
- /// describes the shape to be used while lowering. This matches the shape of
- /// the result value of the instruction, with the only exceptions being store
- /// instructions and the matrix_column_major_store intrinsics. For those, the
- /// shape information indicates that those instructions should be lowered
- /// using shape information as well.
- DenseMap<Value *, ShapeInfo> ShapeMap;
-
- /// List of instructions to remove. While lowering, we do not replace all
- /// users of a lowered instruction if shape information is available; these
- /// instructions need to be removed after we finish lowering.
- SmallVector<Instruction *, 16> ToRemove;
-
- /// Map from instructions to their produced column matrix.
- MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
-
-public:
- LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
+ "block");
+ }
+ };
+
+ struct ShapeInfo {
+ unsigned NumRows;
+ unsigned NumColumns;
+
+ bool IsColumnMajor;
+
+ ShapeInfo(unsigned NumRows = 0, unsigned NumColumns = 0)
+ : NumRows(NumRows), NumColumns(NumColumns),
+ IsColumnMajor(MatrixLayout == MatrixLayoutTy::ColumnMajor) {}
+
+ ShapeInfo(Value *NumRows, Value *NumColumns)
+ : ShapeInfo(cast<ConstantInt>(NumRows)->getZExtValue(),
+ cast<ConstantInt>(NumColumns)->getZExtValue()) {}
+
+ bool operator==(const ShapeInfo &other) {
+ return NumRows == other.NumRows && NumColumns == other.NumColumns;
+ }
+ bool operator!=(const ShapeInfo &other) { return !(*this == other); }
+
+ /// Returns true if shape-information is defined, meaning both dimensions
+ /// are != 0.
+ operator bool() const {
+ assert(NumRows == 0 || NumColumns != 0);
+ return NumRows != 0;
+ }
+
+ unsigned getStride() const {
+ if (IsColumnMajor)
+ return NumRows;
+ return NumColumns;
+ }
+
+ unsigned getNumVectors() const {
+ if (IsColumnMajor)
+ return NumColumns;
+ return NumRows;
+ }
+ };
+
+ /// Maps instructions to their shape information. The shape information
+ /// describes the shape to be used while lowering. This matches the shape of
+ /// the result value of the instruction, with the only exceptions being store
+ /// instructions and the matrix_column_major_store intrinsics. For those, the
+ /// shape information indicates that those instructions should be lowered
+ /// using shape information as well.
+ DenseMap<Value *, ShapeInfo> ShapeMap;
+
+ /// List of instructions to remove. While lowering, we do not replace all
+ /// users of a lowered instruction if shape information is available; these
+ /// instructions need to be removed after we finish lowering.
+ SmallVector<Instruction *, 16> ToRemove;
+
+ /// Map from instructions to their produced column matrix.
+ MapVector<Value *, MatrixTy> Inst2ColumnMatrix;
+
+public:
+ LowerMatrixIntrinsics(Function &F, TargetTransformInfo &TTI,
AliasAnalysis *AA, DominatorTree *DT, LoopInfo *LI,
OptimizationRemarkEmitter *ORE)
- : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
- LI(LI), ORE(ORE) {}
-
- unsigned getNumOps(Type *VT) {
- assert(isa<VectorType>(VT) && "Expected vector type");
- return getNumOps(VT->getScalarType(),
- cast<FixedVectorType>(VT)->getNumElements());
- }
-
- //
- /// Return the estimated number of vector ops required for an operation on
- /// \p VT * N.
- unsigned getNumOps(Type *ST, unsigned N) {
- return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
- double(TTI.getRegisterBitWidth(true)));
- }
-
- /// Return the set of vectors that a matrix value is lowered to.
- ///
- /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
- /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
- /// into vectors.
- MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
- IRBuilder<> &Builder) {
- VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
- assert(VType && "MatrixVal must be a vector type");
- assert(cast<FixedVectorType>(VType)->getNumElements() ==
- SI.NumRows * SI.NumColumns &&
- "The vector size must match the number of matrix elements");
-
- // Check if we lowered MatrixVal using shape information. In that case,
- // return the existing matrix, if it matches the requested shape
- // information. If there is a mis-match, embed the result in a flat
- // vector and split it later.
- auto Found = Inst2ColumnMatrix.find(MatrixVal);
- if (Found != Inst2ColumnMatrix.end()) {
- MatrixTy &M = Found->second;
- // Return the found matrix, if its shape matches the requested shape
- // information
- if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
- return M;
-
- MatrixVal = M.embedInVector(Builder);
- }
-
- // Otherwise split MatrixVal.
- SmallVector<Value *, 16> SplitVecs;
- for (unsigned MaskStart = 0;
- MaskStart < cast<FixedVectorType>(VType)->getNumElements();
- MaskStart += SI.getStride()) {
- Value *V = Builder.CreateShuffleVector(
+ : Func(F), DL(F.getParent()->getDataLayout()), TTI(TTI), AA(AA), DT(DT),
+ LI(LI), ORE(ORE) {}
+
+ unsigned getNumOps(Type *VT) {
+ assert(isa<VectorType>(VT) && "Expected vector type");
+ return getNumOps(VT->getScalarType(),
+ cast<FixedVectorType>(VT)->getNumElements());
+ }
+
+ //
+ /// Return the estimated number of vector ops required for an operation on
+ /// \p VT * N.
+ unsigned getNumOps(Type *ST, unsigned N) {
+ return std::ceil((ST->getPrimitiveSizeInBits() * N).getFixedSize() /
+ double(TTI.getRegisterBitWidth(true)));
+ }
+
+ /// Return the set of vectors that a matrix value is lowered to.
+ ///
+ /// If we lowered \p MatrixVal, just return the cached result matrix. Otherwise
+ /// split the flat vector \p MatrixVal containing a matrix with shape \p SI
+ /// into vectors.
+ MatrixTy getMatrix(Value *MatrixVal, const ShapeInfo &SI,
+ IRBuilder<> &Builder) {
+ VectorType *VType = dyn_cast<VectorType>(MatrixVal->getType());
+ assert(VType && "MatrixVal must be a vector type");
+ assert(cast<FixedVectorType>(VType)->getNumElements() ==
+ SI.NumRows * SI.NumColumns &&
+ "The vector size must match the number of matrix elements");
+
+ // Check if we lowered MatrixVal using shape information. In that case,
+ // return the existing matrix, if it matches the requested shape
+ // information. If there is a mis-match, embed the result in a flat
+ // vector and split it later.
+ auto Found = Inst2ColumnMatrix.find(MatrixVal);
+ if (Found != Inst2ColumnMatrix.end()) {
+ MatrixTy &M = Found->second;
+ // Return the found matrix, if its shape matches the requested shape
+ // information
+ if (SI.NumRows == M.getNumRows() && SI.NumColumns == M.getNumColumns())
+ return M;
+
+ MatrixVal = M.embedInVector(Builder);
+ }
+
+ // Otherwise split MatrixVal.
+ SmallVector<Value *, 16> SplitVecs;
+ for (unsigned MaskStart = 0;
+ MaskStart < cast<FixedVectorType>(VType)->getNumElements();
+ MaskStart += SI.getStride()) {
+ Value *V = Builder.CreateShuffleVector(
MatrixVal, createSequentialMask(MaskStart, SI.getStride(), 0),
- "split");
- SplitVecs.push_back(V);
- }
-
- return {SplitVecs};
- }
-
- /// If \p V already has a known shape return false. Otherwise set the shape
- /// for instructions that support it.
- bool setShapeInfo(Value *V, ShapeInfo Shape) {
- assert(Shape && "Shape not set");
- if (isa<UndefValue>(V) || !supportsShapeInfo(V))
- return false;
-
- auto SIter = ShapeMap.find(V);
- if (SIter != ShapeMap.end()) {
- LLVM_DEBUG(dbgs() << " not overriding existing shape: "
- << SIter->second.NumRows << " "
- << SIter->second.NumColumns << " for " << *V << "\n");
- return false;
- }
-
- ShapeMap.insert({V, Shape});
- LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
- << " for " << *V << "\n");
- return true;
- }
-
- bool isUniformShape(Value *V) {
- Instruction *I = dyn_cast<Instruction>(V);
- if (!I)
- return true;
-
- switch (I->getOpcode()) {
- case Instruction::FAdd:
- case Instruction::FSub:
- case Instruction::FMul: // Scalar multiply.
+ "split");
+ SplitVecs.push_back(V);
+ }
+
+ return {SplitVecs};
+ }
+
+ /// If \p V already has a known shape return false. Otherwise set the shape
+ /// for instructions that support it.
+ bool setShapeInfo(Value *V, ShapeInfo Shape) {
+ assert(Shape && "Shape not set");
+ if (isa<UndefValue>(V) || !supportsShapeInfo(V))
+ return false;
+
+ auto SIter = ShapeMap.find(V);
+ if (SIter != ShapeMap.end()) {
+ LLVM_DEBUG(dbgs() << " not overriding existing shape: "
+ << SIter->second.NumRows << " "
+ << SIter->second.NumColumns << " for " << *V << "\n");
+ return false;
+ }
+
+ ShapeMap.insert({V, Shape});
+ LLVM_DEBUG(dbgs() << " " << Shape.NumRows << " x " << Shape.NumColumns
+ << " for " << *V << "\n");
+ return true;
+ }
+
+ bool isUniformShape(Value *V) {
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (!I)
+ return true;
+
+ switch (I->getOpcode()) {
+ case Instruction::FAdd:
+ case Instruction::FSub:
+ case Instruction::FMul: // Scalar multiply.
case Instruction::FNeg:
- case Instruction::Add:
- case Instruction::Mul:
- case Instruction::Sub:
- return true;
- default:
- return false;
- }
- }
-
- /// Returns true if shape information can be used for \p V. The supported
- /// instructions must match the instructions that can be lowered by this pass.
- bool supportsShapeInfo(Value *V) {
- Instruction *Inst = dyn_cast<Instruction>(V);
- if (!Inst)
- return false;
-
- IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
- if (II)
- switch (II->getIntrinsicID()) {
- case Intrinsic::matrix_multiply:
- case Intrinsic::matrix_transpose:
- case Intrinsic::matrix_column_major_load:
- case Intrinsic::matrix_column_major_store:
- return true;
- default:
- return false;
- }
- return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
- }
-
- /// Propagate the shape information of instructions to their users.
- /// The work list contains instructions for which we can compute the shape,
- /// either based on the information provided by matrix intrinsics or known
- /// shapes of operands.
- SmallVector<Instruction *, 32>
- propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
- SmallVector<Instruction *, 32> NewWorkList;
- // Pop an element for which we are guaranteed to have at least one of the
- // operand shapes. Add the shape for this and then add users to the work
- // list.
- LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
- while (!WorkList.empty()) {
+ case Instruction::Add:
+ case Instruction::Mul:
+ case Instruction::Sub:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ /// Returns true if shape information can be used for \p V. The supported
+ /// instructions must match the instructions that can be lowered by this pass.
+ bool supportsShapeInfo(Value *V) {
+ Instruction *Inst = dyn_cast<Instruction>(V);
+ if (!Inst)
+ return false;
+
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
+ if (II)
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::matrix_multiply:
+ case Intrinsic::matrix_transpose:
+ case Intrinsic::matrix_column_major_load:
+ case Intrinsic::matrix_column_major_store:
+ return true;
+ default:
+ return false;
+ }
+ return isUniformShape(V) || isa<StoreInst>(V) || isa<LoadInst>(V);
+ }
+
+ /// Propagate the shape information of instructions to their users.
+ /// The work list contains instructions for which we can compute the shape,
+ /// either based on the information provided by matrix intrinsics or known
+ /// shapes of operands.
+ SmallVector<Instruction *, 32>
+ propagateShapeForward(SmallVectorImpl<Instruction *> &WorkList) {
+ SmallVector<Instruction *, 32> NewWorkList;
+ // Pop an element for which we are guaranteed to have at least one of the
+ // operand shapes. Add the shape for this and then add users to the work
+ // list.
+ LLVM_DEBUG(dbgs() << "Forward-propagate shapes:\n");
+ while (!WorkList.empty()) {
Instruction *Inst = WorkList.pop_back_val();
-
- // New entry, set the value and insert operands
- bool Propagate = false;
-
- Value *MatrixA;
- Value *MatrixB;
- Value *M;
- Value *N;
- Value *K;
- if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
- m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
- m_Value(N), m_Value(K)))) {
- Propagate = setShapeInfo(Inst, {M, K});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
- m_Value(MatrixA), m_Value(M), m_Value(N)))) {
- // Flip dimensions.
- Propagate = setShapeInfo(Inst, {N, M});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
- m_Value(MatrixA), m_Value(), m_Value(),
- m_Value(), m_Value(M), m_Value(N)))) {
- Propagate = setShapeInfo(Inst, {N, M});
- } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
- m_Value(), m_Value(), m_Value(), m_Value(M),
- m_Value(N)))) {
- Propagate = setShapeInfo(Inst, {M, N});
- } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
- auto OpShape = ShapeMap.find(MatrixA);
- if (OpShape != ShapeMap.end())
- setShapeInfo(Inst, OpShape->second);
- continue;
- } else if (isUniformShape(Inst)) {
- // Find the first operand that has a known shape and use that.
- for (auto &Op : Inst->operands()) {
- auto OpShape = ShapeMap.find(Op.get());
- if (OpShape != ShapeMap.end()) {
- Propagate |= setShapeInfo(Inst, OpShape->second);
- break;
- }
- }
- }
-
- if (Propagate) {
- NewWorkList.push_back(Inst);
- for (auto *User : Inst->users())
- if (ShapeMap.count(User) == 0)
- WorkList.push_back(cast<Instruction>(User));
- }
- }
-
- return NewWorkList;
- }
-
- /// Propagate the shape to operands of instructions with shape information.
- /// \p Worklist contains the instruction for which we already know the shape.
- SmallVector<Instruction *, 32>
- propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
- SmallVector<Instruction *, 32> NewWorkList;
-
- auto pushInstruction = [](Value *V,
- SmallVectorImpl<Instruction *> &WorkList) {
- Instruction *I = dyn_cast<Instruction>(V);
- if (I)
- WorkList.push_back(I);
- };
- // Pop an element with known shape. Traverse the operands; if their shape
- // derives from the result shape and is unknown, add it and add them to the
- // worklist.
- LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
- while (!WorkList.empty()) {
+
+ // New entry, set the value and insert operands
+ bool Propagate = false;
+
+ Value *MatrixA;
+ Value *MatrixB;
+ Value *M;
+ Value *N;
+ Value *K;
+ if (match(Inst, m_Intrinsic<Intrinsic::matrix_multiply>(
+ m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
+ m_Value(N), m_Value(K)))) {
+ Propagate = setShapeInfo(Inst, {M, K});
+ } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
+ m_Value(MatrixA), m_Value(M), m_Value(N)))) {
+ // Flip dimensions.
+ Propagate = setShapeInfo(Inst, {N, M});
+ } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_store>(
+ m_Value(MatrixA), m_Value(), m_Value(),
+ m_Value(), m_Value(M), m_Value(N)))) {
+ Propagate = setShapeInfo(Inst, {N, M});
+ } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_column_major_load>(
+ m_Value(), m_Value(), m_Value(), m_Value(M),
+ m_Value(N)))) {
+ Propagate = setShapeInfo(Inst, {M, N});
+ } else if (match(Inst, m_Store(m_Value(MatrixA), m_Value()))) {
+ auto OpShape = ShapeMap.find(MatrixA);
+ if (OpShape != ShapeMap.end())
+ setShapeInfo(Inst, OpShape->second);
+ continue;
+ } else if (isUniformShape(Inst)) {
+ // Find the first operand that has a known shape and use that.
+ for (auto &Op : Inst->operands()) {
+ auto OpShape = ShapeMap.find(Op.get());
+ if (OpShape != ShapeMap.end()) {
+ Propagate |= setShapeInfo(Inst, OpShape->second);
+ break;
+ }
+ }
+ }
+
+ if (Propagate) {
+ NewWorkList.push_back(Inst);
+ for (auto *User : Inst->users())
+ if (ShapeMap.count(User) == 0)
+ WorkList.push_back(cast<Instruction>(User));
+ }
+ }
+
+ return NewWorkList;
+ }
+
+ /// Propagate the shape to operands of instructions with shape information.
+ /// \p Worklist contains the instruction for which we already know the shape.
+ SmallVector<Instruction *, 32>
+ propagateShapeBackward(SmallVectorImpl<Instruction *> &WorkList) {
+ SmallVector<Instruction *, 32> NewWorkList;
+
+ auto pushInstruction = [](Value *V,
+ SmallVectorImpl<Instruction *> &WorkList) {
+ Instruction *I = dyn_cast<Instruction>(V);
+ if (I)
+ WorkList.push_back(I);
+ };
+ // Pop an element with known shape. Traverse the operands; if their shape
+ // derives from the result shape and is unknown, add it and add them to the
+ // worklist.
+ LLVM_DEBUG(dbgs() << "Backward-propagate shapes:\n");
+ while (!WorkList.empty()) {
Value *V = WorkList.pop_back_val();
-
- size_t BeforeProcessingV = WorkList.size();
- if (!isa<Instruction>(V))
- continue;
-
- Value *MatrixA;
- Value *MatrixB;
- Value *M;
- Value *N;
- Value *K;
- if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
- m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
- m_Value(N), m_Value(K)))) {
- if (setShapeInfo(MatrixA, {M, N}))
- pushInstruction(MatrixA, WorkList);
-
- if (setShapeInfo(MatrixB, {N, K}))
- pushInstruction(MatrixB, WorkList);
-
- } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
- m_Value(MatrixA), m_Value(M), m_Value(N)))) {
- // Flip dimensions.
- if (setShapeInfo(MatrixA, {M, N}))
- pushInstruction(MatrixA, WorkList);
- } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
- m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
- m_Value(M), m_Value(N)))) {
- if (setShapeInfo(MatrixA, {M, N})) {
- pushInstruction(MatrixA, WorkList);
- }
- } else if (isa<LoadInst>(V) ||
- match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
- // Nothing to do, no matrix input.
- } else if (isa<StoreInst>(V)) {
- // Nothing to do. We forward-propagated to this so we would just
- // backward propagate to an instruction with an already known shape.
- } else if (isUniformShape(V)) {
- // Propagate to all operands.
- ShapeInfo Shape = ShapeMap[V];
- for (Use &U : cast<Instruction>(V)->operands()) {
- if (setShapeInfo(U.get(), Shape))
- pushInstruction(U.get(), WorkList);
- }
- }
- // After we discovered new shape info for new instructions in the
- // worklist, we use their users as seeds for the next round of forward
- // propagation.
- for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
- for (User *U : WorkList[I]->users())
- if (isa<Instruction>(U) && V != U)
- NewWorkList.push_back(cast<Instruction>(U));
- }
- return NewWorkList;
- }
-
- bool Visit() {
- if (EnableShapePropagation) {
- SmallVector<Instruction *, 32> WorkList;
-
- // Initially only the shape of matrix intrinsics is known.
- // Initialize the work list with ops carrying shape information.
- for (BasicBlock &BB : Func)
- for (Instruction &Inst : BB) {
- IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
- if (!II)
- continue;
-
- switch (II->getIntrinsicID()) {
- case Intrinsic::matrix_multiply:
- case Intrinsic::matrix_transpose:
- case Intrinsic::matrix_column_major_load:
- case Intrinsic::matrix_column_major_store:
- WorkList.push_back(&Inst);
- break;
- default:
- break;
- }
- }
- // Propagate shapes until nothing changes any longer.
- while (!WorkList.empty()) {
- WorkList = propagateShapeForward(WorkList);
- WorkList = propagateShapeBackward(WorkList);
- }
- }
-
- bool Changed = false;
- SmallVector<CallInst *, 16> MaybeFusableInsts;
- SmallVector<Instruction *, 16> MatrixInsts;
-
- // First, collect all instructions with shape information and candidates for
- // fusion (currently only matrix multiplies).
- ReversePostOrderTraversal<Function *> RPOT(&Func);
- for (auto *BB : RPOT)
- for (Instruction &I : *BB) {
- if (ShapeMap.find(&I) == ShapeMap.end())
- continue;
- if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
- MaybeFusableInsts.push_back(cast<CallInst>(&I));
- MatrixInsts.push_back(&I);
- }
-
- // Second, try to fuse candidates.
- SmallPtrSet<Instruction *, 16> FusedInsts;
- for (CallInst *CI : MaybeFusableInsts)
- LowerMatrixMultiplyFused(CI, FusedInsts);
- Changed = !FusedInsts.empty();
-
- // Third, lower remaining instructions with shape information.
- for (Instruction *Inst : MatrixInsts) {
- if (FusedInsts.count(Inst))
- continue;
-
- IRBuilder<> Builder(Inst);
-
- if (CallInst *CInst = dyn_cast<CallInst>(Inst))
- Changed |= VisitCallInst(CInst);
-
- Value *Op1;
- Value *Op2;
- if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
- Changed |= VisitBinaryOperator(BinOp);
+
+ size_t BeforeProcessingV = WorkList.size();
+ if (!isa<Instruction>(V))
+ continue;
+
+ Value *MatrixA;
+ Value *MatrixB;
+ Value *M;
+ Value *N;
+ Value *K;
+ if (match(V, m_Intrinsic<Intrinsic::matrix_multiply>(
+ m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
+ m_Value(N), m_Value(K)))) {
+ if (setShapeInfo(MatrixA, {M, N}))
+ pushInstruction(MatrixA, WorkList);
+
+ if (setShapeInfo(MatrixB, {N, K}))
+ pushInstruction(MatrixB, WorkList);
+
+ } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
+ m_Value(MatrixA), m_Value(M), m_Value(N)))) {
+ // Flip dimensions.
+ if (setShapeInfo(MatrixA, {M, N}))
+ pushInstruction(MatrixA, WorkList);
+ } else if (match(V, m_Intrinsic<Intrinsic::matrix_column_major_store>(
+ m_Value(MatrixA), m_Value(), m_Value(), m_Value(),
+ m_Value(M), m_Value(N)))) {
+ if (setShapeInfo(MatrixA, {M, N})) {
+ pushInstruction(MatrixA, WorkList);
+ }
+ } else if (isa<LoadInst>(V) ||
+ match(V, m_Intrinsic<Intrinsic::matrix_column_major_load>())) {
+ // Nothing to do, no matrix input.
+ } else if (isa<StoreInst>(V)) {
+ // Nothing to do. We forward-propagated to this so we would just
+ // backward propagate to an instruction with an already known shape.
+ } else if (isUniformShape(V)) {
+ // Propagate to all operands.
+ ShapeInfo Shape = ShapeMap[V];
+ for (Use &U : cast<Instruction>(V)->operands()) {
+ if (setShapeInfo(U.get(), Shape))
+ pushInstruction(U.get(), WorkList);
+ }
+ }
+ // After we discovered new shape info for new instructions in the
+ // worklist, we use their users as seeds for the next round of forward
+ // propagation.
+ for (size_t I = BeforeProcessingV; I != WorkList.size(); I++)
+ for (User *U : WorkList[I]->users())
+ if (isa<Instruction>(U) && V != U)
+ NewWorkList.push_back(cast<Instruction>(U));
+ }
+ return NewWorkList;
+ }
+
+ bool Visit() {
+ if (EnableShapePropagation) {
+ SmallVector<Instruction *, 32> WorkList;
+
+ // Initially only the shape of matrix intrinsics is known.
+ // Initialize the work list with ops carrying shape information.
+ for (BasicBlock &BB : Func)
+ for (Instruction &Inst : BB) {
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(&Inst);
+ if (!II)
+ continue;
+
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::matrix_multiply:
+ case Intrinsic::matrix_transpose:
+ case Intrinsic::matrix_column_major_load:
+ case Intrinsic::matrix_column_major_store:
+ WorkList.push_back(&Inst);
+ break;
+ default:
+ break;
+ }
+ }
+ // Propagate shapes until nothing changes any longer.
+ while (!WorkList.empty()) {
+ WorkList = propagateShapeForward(WorkList);
+ WorkList = propagateShapeBackward(WorkList);
+ }
+ }
+
+ bool Changed = false;
+ SmallVector<CallInst *, 16> MaybeFusableInsts;
+ SmallVector<Instruction *, 16> MatrixInsts;
+
+ // First, collect all instructions with shape information and candidates for
+ // fusion (currently only matrix multiplies).
+ ReversePostOrderTraversal<Function *> RPOT(&Func);
+ for (auto *BB : RPOT)
+ for (Instruction &I : *BB) {
+ if (ShapeMap.find(&I) == ShapeMap.end())
+ continue;
+ if (match(&I, m_Intrinsic<Intrinsic::matrix_multiply>()))
+ MaybeFusableInsts.push_back(cast<CallInst>(&I));
+ MatrixInsts.push_back(&I);
+ }
+
+ // Second, try to fuse candidates.
+ SmallPtrSet<Instruction *, 16> FusedInsts;
+ for (CallInst *CI : MaybeFusableInsts)
+ LowerMatrixMultiplyFused(CI, FusedInsts);
+ Changed = !FusedInsts.empty();
+
+ // Third, lower remaining instructions with shape information.
+ for (Instruction *Inst : MatrixInsts) {
+ if (FusedInsts.count(Inst))
+ continue;
+
+ IRBuilder<> Builder(Inst);
+
+ if (CallInst *CInst = dyn_cast<CallInst>(Inst))
+ Changed |= VisitCallInst(CInst);
+
+ Value *Op1;
+ Value *Op2;
+ if (auto *BinOp = dyn_cast<BinaryOperator>(Inst))
+ Changed |= VisitBinaryOperator(BinOp);
if (auto *UnOp = dyn_cast<UnaryOperator>(Inst))
Changed |= VisitUnaryOperator(UnOp);
- if (match(Inst, m_Load(m_Value(Op1))))
- Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
- else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
- Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
- }
-
+ if (match(Inst, m_Load(m_Value(Op1))))
+ Changed |= VisitLoad(cast<LoadInst>(Inst), Op1, Builder);
+ else if (match(Inst, m_Store(m_Value(Op1), m_Value(Op2))))
+ Changed |= VisitStore(cast<StoreInst>(Inst), Op1, Op2, Builder);
+ }
+
if (ORE) {
RemarkGenerator RemarkGen(Inst2ColumnMatrix, *ORE, Func);
RemarkGen.emitRemarks();
}
-
- for (Instruction *Inst : reverse(ToRemove))
- Inst->eraseFromParent();
-
- return Changed;
- }
-
- /// Turns \p BasePtr into an elementwise pointer to \p EltType.
- Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
- unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
- Type *EltPtrType = PointerType::get(EltType, AS);
- return Builder.CreatePointerCast(BasePtr, EltPtrType);
- }
-
- /// Replace intrinsic calls
- bool VisitCallInst(CallInst *Inst) {
- if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
- return false;
-
- switch (Inst->getCalledFunction()->getIntrinsicID()) {
- case Intrinsic::matrix_multiply:
- LowerMultiply(Inst);
- break;
- case Intrinsic::matrix_transpose:
- LowerTranspose(Inst);
- break;
- case Intrinsic::matrix_column_major_load:
- LowerColumnMajorLoad(Inst);
- break;
- case Intrinsic::matrix_column_major_store:
- LowerColumnMajorStore(Inst);
- break;
- default:
- return false;
- }
- return true;
- }
-
- /// Compute the alignment for a column/row \p Idx with \p Stride between them.
- /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
- /// ConstantInt, reduce the initial alignment based on the byte offset. For
- /// non-ConstantInt strides, return the common alignment of the initial
- /// alignment and the element size in bytes.
- Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
- MaybeAlign A) const {
- Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
- if (Idx == 0)
- return InitialAlign;
-
- TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
- if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
- uint64_t StrideInBytes =
- ConstStride->getZExtValue() * ElementSizeInBits / 8;
- return commonAlignment(InitialAlign, Idx * StrideInBytes);
- }
- return commonAlignment(InitialAlign, ElementSizeInBits / 8);
- }
-
- /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
- /// vectors.
- MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
- bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
- auto VType = cast<VectorType>(Ty);
- Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
- MatrixTy Result;
- for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
- Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride,
- Shape.getStride(), VType->getElementType(),
- Builder);
- Value *Vector = Builder.CreateAlignedLoad(
- GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign),
- IsVolatile, "col.load");
-
- Result.addVector(Vector);
- }
- return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
- Result.getNumVectors());
- }
-
- /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
- /// starting at \p MatrixPtr[I][J].
- MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
- ShapeInfo MatrixShape, Value *I, Value *J,
- ShapeInfo ResultShape, Type *EltTy,
- IRBuilder<> &Builder) {
-
- Value *Offset = Builder.CreateAdd(
- Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
-
- unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
- Value *EltPtr =
- Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
- Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
- auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
- ResultShape.NumColumns);
- Type *TilePtrTy = PointerType::get(TileTy, AS);
- Value *TilePtr =
- Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
-
- return loadMatrix(TileTy, TilePtr, Align,
- Builder.getInt64(MatrixShape.getStride()), IsVolatile,
- ResultShape, Builder);
- }
-
- /// Lower a load instruction with shape information.
- void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
- bool IsVolatile, ShapeInfo Shape) {
- IRBuilder<> Builder(Inst);
- finalizeLowering(Inst,
- loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
- Shape, Builder),
- Builder);
- }
-
- /// Lowers llvm.matrix.column.major.load.
- ///
- /// The intrinsic loads a matrix from memory using a stride between columns.
- void LowerColumnMajorLoad(CallInst *Inst) {
- assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
- "Intrinsic only supports column-major layout!");
- Value *Ptr = Inst->getArgOperand(0);
- Value *Stride = Inst->getArgOperand(1);
- LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
- cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
- {Inst->getArgOperand(3), Inst->getArgOperand(4)});
- }
-
- /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
- /// MatrixPtr[I][J].
- void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
- MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
- Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
- Value *Offset = Builder.CreateAdd(
- Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
-
- unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
- Value *EltPtr =
- Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
- Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
- auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
- StoreVal.getNumColumns());
- Type *TilePtrTy = PointerType::get(TileTy, AS);
- Value *TilePtr =
- Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
-
- storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
- Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
- }
-
- /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
- /// vectors.
- MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
- MaybeAlign MAlign, Value *Stride, bool IsVolatile,
- IRBuilder<> &Builder) {
- auto VType = cast<VectorType>(Ty);
- Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
- for (auto Vec : enumerate(StoreVal.vectors())) {
- Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
- Stride, StoreVal.getStride(),
- VType->getElementType(), Builder);
- Builder.CreateAlignedStore(Vec.value(), GEP,
- getAlignForIndex(Vec.index(), Stride,
- VType->getElementType(),
- MAlign),
- IsVolatile);
- }
- return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
- StoreVal.getNumVectors());
- }
-
- /// Lower a store instruction with shape information.
- void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
- Value *Stride, bool IsVolatile, ShapeInfo Shape) {
- IRBuilder<> Builder(Inst);
- auto StoreVal = getMatrix(Matrix, Shape, Builder);
- finalizeLowering(Inst,
- storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
- IsVolatile, Builder),
- Builder);
- }
-
- /// Lowers llvm.matrix.column.major.store.
- ///
- /// The intrinsic stores a matrix back to memory using a stride between columns.
- void LowerColumnMajorStore(CallInst *Inst) {
- assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
- "Intrinsic only supports column-major layout!");
- Value *Matrix = Inst->getArgOperand(0);
- Value *Ptr = Inst->getArgOperand(1);
- Value *Stride = Inst->getArgOperand(2);
- LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
- cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
- {Inst->getArgOperand(4), Inst->getArgOperand(5)});
- }
-
- // Set elements I..I+NumElts-1 to Block
- Value *insertVector(Value *Col, unsigned I, Value *Block,
- IRBuilder<> &Builder) {
-
- // First, bring Block to the same size as Col
- unsigned BlockNumElts =
- cast<FixedVectorType>(Block->getType())->getNumElements();
- unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
- assert(NumElts >= BlockNumElts && "Too few elements for current block");
-
- Block = Builder.CreateShuffleVector(
+
+ for (Instruction *Inst : reverse(ToRemove))
+ Inst->eraseFromParent();
+
+ return Changed;
+ }
+
+ /// Turns \p BasePtr into an elementwise pointer to \p EltType.
+ Value *createElementPtr(Value *BasePtr, Type *EltType, IRBuilder<> &Builder) {
+ unsigned AS = cast<PointerType>(BasePtr->getType())->getAddressSpace();
+ Type *EltPtrType = PointerType::get(EltType, AS);
+ return Builder.CreatePointerCast(BasePtr, EltPtrType);
+ }
+
+ /// Replace intrinsic calls
+ bool VisitCallInst(CallInst *Inst) {
+ if (!Inst->getCalledFunction() || !Inst->getCalledFunction()->isIntrinsic())
+ return false;
+
+ switch (Inst->getCalledFunction()->getIntrinsicID()) {
+ case Intrinsic::matrix_multiply:
+ LowerMultiply(Inst);
+ break;
+ case Intrinsic::matrix_transpose:
+ LowerTranspose(Inst);
+ break;
+ case Intrinsic::matrix_column_major_load:
+ LowerColumnMajorLoad(Inst);
+ break;
+ case Intrinsic::matrix_column_major_store:
+ LowerColumnMajorStore(Inst);
+ break;
+ default:
+ return false;
+ }
+ return true;
+ }
+
+ /// Compute the alignment for a column/row \p Idx with \p Stride between them.
+ /// The address at \p Idx == 0 has alignment \p A. If \p Stride is a
+ /// ConstantInt, reduce the initial alignment based on the byte offset. For
+ /// non-ConstantInt strides, return the common alignment of the initial
+ /// alignment and the element size in bytes.
+ Align getAlignForIndex(unsigned Idx, Value *Stride, Type *ElementTy,
+ MaybeAlign A) const {
+ Align InitialAlign = DL.getValueOrABITypeAlignment(A, ElementTy);
+ if (Idx == 0)
+ return InitialAlign;
+
+ TypeSize ElementSizeInBits = DL.getTypeSizeInBits(ElementTy);
+ if (auto *ConstStride = dyn_cast<ConstantInt>(Stride)) {
+ uint64_t StrideInBytes =
+ ConstStride->getZExtValue() * ElementSizeInBits / 8;
+ return commonAlignment(InitialAlign, Idx * StrideInBytes);
+ }
+ return commonAlignment(InitialAlign, ElementSizeInBits / 8);
+ }
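+
+ // Illustrative numbers, not used by the lowering itself: for double elements
+ // (8 bytes), a constant stride of 8 and an initial alignment of 16, column
+ // Idx = 1 starts 8 * 8 = 64 bytes into the matrix, so
+ // commonAlignment(Align(16), 64) keeps the full 16-byte alignment, while a
+ // stride of 3 would reduce it to commonAlignment(Align(16), 24) == Align(8).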
+
+ /// Load a matrix with \p Shape starting at \p Ptr and using \p Stride between
+ /// vectors.
+ MatrixTy loadMatrix(Type *Ty, Value *Ptr, MaybeAlign MAlign, Value *Stride,
+ bool IsVolatile, ShapeInfo Shape, IRBuilder<> &Builder) {
+ auto VType = cast<VectorType>(Ty);
+ Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
+ MatrixTy Result;
+ for (unsigned I = 0, E = Shape.getNumVectors(); I < E; ++I) {
+ Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(I), Stride,
+ Shape.getStride(), VType->getElementType(),
+ Builder);
+ Value *Vector = Builder.CreateAlignedLoad(
+ GEP, getAlignForIndex(I, Stride, VType->getElementType(), MAlign),
+ IsVolatile, "col.load");
+
+ Result.addVector(Vector);
+ }
+ return Result.addNumLoads(getNumOps(Result.getVectorTy()) *
+ Result.getNumVectors());
+ }
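+
+ // Rough sketch of the emitted IR, assuming a 4 x 2 matrix of doubles loaded
+ // column-major with stride 4: the loop above produces one GEP and one
+ // "col.load" per column, e.g.
+ //   %col.load = load <4 x double>, <4 x double>* %ptr.0
+ //   %col.load.1 = load <4 x double>, <4 x double>* %ptr.1
+ // where the second address is offset by Stride (4) elements and the value
+ // names other than "col.load" are illustrative only.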
+
+ /// Loads a sub-matrix with shape \p ResultShape from a \p R x \p C matrix,
+ /// starting at \p MatrixPtr[I][J].
+ MatrixTy loadMatrix(Value *MatrixPtr, MaybeAlign Align, bool IsVolatile,
+ ShapeInfo MatrixShape, Value *I, Value *J,
+ ShapeInfo ResultShape, Type *EltTy,
+ IRBuilder<> &Builder) {
+
+ Value *Offset = Builder.CreateAdd(
+ Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
+
+ unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
+ Value *EltPtr =
+ Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
+ Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+ auto *TileTy = FixedVectorType::get(EltTy, ResultShape.NumRows *
+ ResultShape.NumColumns);
+ Type *TilePtrTy = PointerType::get(TileTy, AS);
+ Value *TilePtr =
+ Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
+
+ return loadMatrix(TileTy, TilePtr, Align,
+ Builder.getInt64(MatrixShape.getStride()), IsVolatile,
+ ResultShape, Builder);
+ }
+
+ /// Lower a load instruction with shape information.
+ void LowerLoad(Instruction *Inst, Value *Ptr, MaybeAlign Align, Value *Stride,
+ bool IsVolatile, ShapeInfo Shape) {
+ IRBuilder<> Builder(Inst);
+ finalizeLowering(Inst,
+ loadMatrix(Inst->getType(), Ptr, Align, Stride, IsVolatile,
+ Shape, Builder),
+ Builder);
+ }
+
+ /// Lowers llvm.matrix.column.major.load.
+ ///
+ /// The intrinsic loads a matrix from memory using a stride between columns.
+ void LowerColumnMajorLoad(CallInst *Inst) {
+ assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
+ "Intrinsic only supports column-major layout!");
+ Value *Ptr = Inst->getArgOperand(0);
+ Value *Stride = Inst->getArgOperand(1);
+ LowerLoad(Inst, Ptr, Inst->getParamAlign(0), Stride,
+ cast<ConstantInt>(Inst->getArgOperand(2))->isOne(),
+ {Inst->getArgOperand(3), Inst->getArgOperand(4)});
+ }
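+
+ // For reference, the operand order assumed above is
+ //   llvm.matrix.column.major.load(%Ptr, i64 %Stride, i1 %IsVolatile,
+ //                                 i32 %Rows, i32 %Cols)
+ // i.e. operand 2 carries the volatile flag and operands 3 and 4 the shape.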
+
+ /// Stores a sub-matrix \p StoreVal into the \p R x \p C matrix starting at \p
+ /// MatrixPtr[I][J].
+ void storeMatrix(const MatrixTy &StoreVal, Value *MatrixPtr,
+ MaybeAlign MAlign, bool IsVolatile, ShapeInfo MatrixShape,
+ Value *I, Value *J, Type *EltTy, IRBuilder<> &Builder) {
+ Value *Offset = Builder.CreateAdd(
+ Builder.CreateMul(J, Builder.getInt64(MatrixShape.getStride())), I);
+
+ unsigned AS = cast<PointerType>(MatrixPtr->getType())->getAddressSpace();
+ Value *EltPtr =
+ Builder.CreatePointerCast(MatrixPtr, PointerType::get(EltTy, AS));
+ Value *TileStart = Builder.CreateGEP(EltTy, EltPtr, Offset);
+ auto *TileTy = FixedVectorType::get(EltTy, StoreVal.getNumRows() *
+ StoreVal.getNumColumns());
+ Type *TilePtrTy = PointerType::get(TileTy, AS);
+ Value *TilePtr =
+ Builder.CreatePointerCast(TileStart, TilePtrTy, "col.cast");
+
+ storeMatrix(TileTy, StoreVal, TilePtr, MAlign,
+ Builder.getInt64(MatrixShape.getStride()), IsVolatile, Builder);
+ }
+
+ /// Store matrix \p StoreVal starting at \p Ptr and using \p Stride between
+ /// vectors.
+ MatrixTy storeMatrix(Type *Ty, MatrixTy StoreVal, Value *Ptr,
+ MaybeAlign MAlign, Value *Stride, bool IsVolatile,
+ IRBuilder<> &Builder) {
+ auto VType = cast<VectorType>(Ty);
+ Value *EltPtr = createElementPtr(Ptr, VType->getElementType(), Builder);
+ for (auto Vec : enumerate(StoreVal.vectors())) {
+ Value *GEP = computeVectorAddr(EltPtr, Builder.getInt64(Vec.index()),
+ Stride, StoreVal.getStride(),
+ VType->getElementType(), Builder);
+ Builder.CreateAlignedStore(Vec.value(), GEP,
+ getAlignForIndex(Vec.index(), Stride,
+ VType->getElementType(),
+ MAlign),
+ IsVolatile);
+ }
+ return MatrixTy().addNumStores(getNumOps(StoreVal.getVectorTy()) *
+ StoreVal.getNumVectors());
+ }
+
+ /// Lower a store instruction with shape information.
+ void LowerStore(Instruction *Inst, Value *Matrix, Value *Ptr, MaybeAlign A,
+ Value *Stride, bool IsVolatile, ShapeInfo Shape) {
+ IRBuilder<> Builder(Inst);
+ auto StoreVal = getMatrix(Matrix, Shape, Builder);
+ finalizeLowering(Inst,
+ storeMatrix(Matrix->getType(), StoreVal, Ptr, A, Stride,
+ IsVolatile, Builder),
+ Builder);
+ }
+
+ /// Lowers llvm.matrix.column.major.store.
+ ///
+ /// The intrinsic stores a matrix back to memory using a stride between columns.
+ void LowerColumnMajorStore(CallInst *Inst) {
+ assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
+ "Intrinsic only supports column-major layout!");
+ Value *Matrix = Inst->getArgOperand(0);
+ Value *Ptr = Inst->getArgOperand(1);
+ Value *Stride = Inst->getArgOperand(2);
+ LowerStore(Inst, Matrix, Ptr, Inst->getParamAlign(1), Stride,
+ cast<ConstantInt>(Inst->getArgOperand(3))->isOne(),
+ {Inst->getArgOperand(4), Inst->getArgOperand(5)});
+ }
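+
+ // For reference, the operand order assumed above is
+ //   llvm.matrix.column.major.store(%Value, %Ptr, i64 %Stride, i1 %IsVolatile,
+ //                                  i32 %Rows, i32 %Cols)
+ // i.e. operand 3 carries the volatile flag and operands 4 and 5 the shape.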
+
+ // Set elements I..I+NumElts-1 to Block
+ Value *insertVector(Value *Col, unsigned I, Value *Block,
+ IRBuilder<> &Builder) {
+
+ // First, bring Block to the same size as Col
+ unsigned BlockNumElts =
+ cast<FixedVectorType>(Block->getType())->getNumElements();
+ unsigned NumElts = cast<FixedVectorType>(Col->getType())->getNumElements();
+ assert(NumElts >= BlockNumElts && "Too few elements for current block");
+
+ Block = Builder.CreateShuffleVector(
Block, createSequentialMask(0, BlockNumElts, NumElts - BlockNumElts));
-
- // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
- // 8, 4, 5, 6
- SmallVector<int, 16> Mask;
- unsigned i;
- for (i = 0; i < I; i++)
- Mask.push_back(i);
-
- unsigned VecNumElts =
- cast<FixedVectorType>(Col->getType())->getNumElements();
- for (; i < I + BlockNumElts; i++)
- Mask.push_back(i - I + VecNumElts);
-
- for (; i < VecNumElts; i++)
- Mask.push_back(i);
-
- return Builder.CreateShuffleVector(Col, Block, Mask);
- }
-
- Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
- IRBuilder<> &Builder, bool AllowContraction,
- unsigned &NumComputeOps) {
- NumComputeOps += getNumOps(A->getType());
- if (!Sum)
- return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
-
- if (UseFPOp) {
- if (AllowContraction) {
- // Use fmuladd for floating point operations and let the backend decide
- // if that's profitable.
- Function *FMulAdd = Intrinsic::getDeclaration(
- Func.getParent(), Intrinsic::fmuladd, A->getType());
- return Builder.CreateCall(FMulAdd, {A, B, Sum});
- }
- NumComputeOps += getNumOps(A->getType());
- Value *Mul = Builder.CreateFMul(A, B);
- return Builder.CreateFAdd(Sum, Mul);
- }
-
- NumComputeOps += getNumOps(A->getType());
- Value *Mul = Builder.CreateMul(A, B);
- return Builder.CreateAdd(Sum, Mul);
- }
-
- /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
- /// users with shape information, there's nothing to do: they will use the
- /// cached value when they are lowered. For other users, \p Matrix is
- /// flattened and the uses are updated to use it. Also marks \p Inst for
- /// deletion.
- void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
- IRBuilder<> &Builder) {
- Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
-
- ToRemove.push_back(Inst);
- Value *Flattened = nullptr;
- for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
- Use &U = *I++;
- if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
- if (!Flattened)
- Flattened = Matrix.embedInVector(Builder);
- U.set(Flattened);
- }
- }
- }
-
- /// Compute \p Result += \p A * \p B for input matrices with left-associating
- /// addition.
- void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
- const MatrixTy &B, bool AllowContraction,
- IRBuilder<> &Builder, bool isTiled) {
- const unsigned VF = std::max<unsigned>(
- TTI.getRegisterBitWidth(true) /
- Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
- 1U);
- unsigned R = Result.getNumRows();
- unsigned C = Result.getNumColumns();
- unsigned M = A.getNumColumns();
-
- bool IsFP = Result.getElementType()->isFloatingPointTy();
- assert(A.isColumnMajor() == B.isColumnMajor() &&
- Result.isColumnMajor() == A.isColumnMajor() &&
- "operands must agree on matrix layout");
- unsigned NumComputeOps = 0;
- if (A.isColumnMajor()) {
- // Multiply columns from the first operand with scalars from the second
- // operand. Then move along the K axes and accumulate the columns. With
- // this the adds can be vectorized without reassociation.
- for (unsigned J = 0; J < C; ++J) {
- unsigned BlockSize = VF;
- // If Result is zero, we don't need to accumulate in the K==0 iteration.
- bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
-
- for (unsigned I = 0; I < R; I += BlockSize) {
- // Gradually lower the vectorization factor to cover the remainder.
- while (I + BlockSize > R)
- BlockSize /= 2;
-
- Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder)
- : nullptr;
- for (unsigned K = 0; K < M; ++K) {
- Value *L = A.extractVector(I, K, BlockSize, Builder);
- Value *RH = Builder.CreateExtractElement(B.getColumn(J), K);
- Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
- Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
- Result.getElementType()->isFloatingPointTy(),
- Builder, AllowContraction, NumComputeOps);
- }
- Result.setVector(J,
- insertVector(Result.getVector(J), I, Sum, Builder));
- }
- }
- } else {
- // Multiply rows from the second operand with scalars from the first
- // operand. Then move along the K axes and accumulate the rows. With this
- // the adds can be vectorized without reassociation.
- for (unsigned I = 0; I < R; ++I) {
- unsigned BlockSize = VF;
- bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
- for (unsigned J = 0; J < C; J += BlockSize) {
- // Gradually lower the vectorization factor to cover the remainder.
- while (J + BlockSize > C)
- BlockSize /= 2;
-
- Value *Sum = nullptr;
- for (unsigned K = 0; K < M; ++K) {
- Value *R = B.extractVector(K, J, BlockSize, Builder);
- Value *LH = Builder.CreateExtractElement(A.getVector(I), K);
- Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
- Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
- IsFP, Builder, AllowContraction, NumComputeOps);
- }
- Result.setVector(I,
- insertVector(Result.getVector(I), J, Sum, Builder));
- }
- }
- }
- Result.addNumComputeOps(NumComputeOps);
- }
-
- /// Ensure that the memory in \p Load does not alias \p Store by potentially
- /// copying it to a new location. The new location, or otherwise the original
- /// one, is returned.
- Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
- CallInst *MatMul) {
- MemoryLocation StoreLoc = MemoryLocation::get(Store);
- MemoryLocation LoadLoc = MemoryLocation::get(Load);
-
+
+ // If Col is 7 long and I is 2 and BlockNumElts is 2 the mask is: 0, 1, 7,
+ // 8, 4, 5, 6
+ SmallVector<int, 16> Mask;
+ unsigned i;
+ for (i = 0; i < I; i++)
+ Mask.push_back(i);
+
+ unsigned VecNumElts =
+ cast<FixedVectorType>(Col->getType())->getNumElements();
+ for (; i < I + BlockNumElts; i++)
+ Mask.push_back(i - I + VecNumElts);
+
+ for (; i < VecNumElts; i++)
+ Mask.push_back(i);
+
+ return Builder.CreateShuffleVector(Col, Block, Mask);
+ }
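+
+ // Worked example for the widening shuffle above: with BlockNumElts == 2 and
+ // NumElts == 7, createSequentialMask(0, 2, 5) yields {0, 1, -1, -1, -1, -1,
+ // -1}, i.e. Block is padded to 7 lanes with undef before it is blended into
+ // Col using the mask documented in the comment above.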
+
+ Value *createMulAdd(Value *Sum, Value *A, Value *B, bool UseFPOp,
+ IRBuilder<> &Builder, bool AllowContraction,
+ unsigned &NumComputeOps) {
+ NumComputeOps += getNumOps(A->getType());
+ if (!Sum)
+ return UseFPOp ? Builder.CreateFMul(A, B) : Builder.CreateMul(A, B);
+
+ if (UseFPOp) {
+ if (AllowContraction) {
+ // Use fmuladd for floating point operations and let the backend decide
+ // if that's profitable.
+ Function *FMulAdd = Intrinsic::getDeclaration(
+ Func.getParent(), Intrinsic::fmuladd, A->getType());
+ return Builder.CreateCall(FMulAdd, {A, B, Sum});
+ }
+ NumComputeOps += getNumOps(A->getType());
+ Value *Mul = Builder.CreateFMul(A, B);
+ return Builder.CreateFAdd(Sum, Mul);
+ }
+
+ NumComputeOps += getNumOps(A->getType());
+ Value *Mul = Builder.CreateMul(A, B);
+ return Builder.CreateAdd(Sum, Mul);
+ }
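+
+ // Sketch of the two floating-point forms produced above for one block:
+ //   with contraction:    %s = call @llvm.fmuladd(%A, %B, %Sum)
+ //   without contraction: %m = fmul %A, %B  followed by  %s = fadd %Sum, %m
+ // The integer path always uses the separate mul/add pair.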
+
+ /// Cache \p Matrix as result of \p Inst and update the uses of \p Inst. For
+ /// users with shape information, there's nothing to do: they will use the
+ /// cached value when they are lowered. For other users, \p Matrix is
+ /// flattened and the uses are updated to use it. Also marks \p Inst for
+ /// deletion.
+ void finalizeLowering(Instruction *Inst, MatrixTy Matrix,
+ IRBuilder<> &Builder) {
+ Inst2ColumnMatrix.insert(std::make_pair(Inst, Matrix));
+
+ ToRemove.push_back(Inst);
+ Value *Flattened = nullptr;
+ for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
+ Use &U = *I++;
+ if (ShapeMap.find(U.getUser()) == ShapeMap.end()) {
+ if (!Flattened)
+ Flattened = Matrix.embedInVector(Builder);
+ U.set(Flattened);
+ }
+ }
+ }
+
+ /// Compute \p Result += \p A * \p B for input matrices with left-associating
+ /// addition.
+ void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
+ const MatrixTy &B, bool AllowContraction,
+ IRBuilder<> &Builder, bool isTiled) {
+ const unsigned VF = std::max<unsigned>(
+ TTI.getRegisterBitWidth(true) /
+ Result.getElementType()->getPrimitiveSizeInBits().getFixedSize(),
+ 1U);
+ unsigned R = Result.getNumRows();
+ unsigned C = Result.getNumColumns();
+ unsigned M = A.getNumColumns();
+
+ bool IsFP = Result.getElementType()->isFloatingPointTy();
+ assert(A.isColumnMajor() == B.isColumnMajor() &&
+ Result.isColumnMajor() == A.isColumnMajor() &&
+ "operands must agree on matrix layout");
+ unsigned NumComputeOps = 0;
+ if (A.isColumnMajor()) {
+ // Multiply columns from the first operand with scalars from the second
+ // operand. Then move along the K axes and accumulate the columns. With
+ // this the adds can be vectorized without reassociation.
+ for (unsigned J = 0; J < C; ++J) {
+ unsigned BlockSize = VF;
+ // If Result is zero, we don't need to accumulate in the K==0 iteration.
+ bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
+
+ for (unsigned I = 0; I < R; I += BlockSize) {
+ // Gradually lower the vectorization factor to cover the remainder.
+ while (I + BlockSize > R)
+ BlockSize /= 2;
+
+ Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder)
+ : nullptr;
+ for (unsigned K = 0; K < M; ++K) {
+ Value *L = A.extractVector(I, K, BlockSize, Builder);
+ Value *RH = Builder.CreateExtractElement(B.getColumn(J), K);
+ Value *Splat = Builder.CreateVectorSplat(BlockSize, RH, "splat");
+ Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, L, Splat,
+ Result.getElementType()->isFloatingPointTy(),
+ Builder, AllowContraction, NumComputeOps);
+ }
+ Result.setVector(J,
+ insertVector(Result.getVector(J), I, Sum, Builder));
+ }
+ }
+ } else {
+ // Multiply rows from the second operand with scalars from the first
+ // operand. Then move along the K axes and accumulate the rows. With this
+ // the adds can be vectorized without reassociation.
+ for (unsigned I = 0; I < R; ++I) {
+ unsigned BlockSize = VF;
+ bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
+ for (unsigned J = 0; J < C; J += BlockSize) {
+ // Gradually lower the vectorization factor to cover the remainder.
+ while (J + BlockSize > C)
+ BlockSize /= 2;
+
+ Value *Sum = nullptr;
+ for (unsigned K = 0; K < M; ++K) {
+ Value *R = B.extractVector(K, J, BlockSize, Builder);
+ Value *LH = Builder.CreateExtractElement(A.getVector(I), K);
+ Value *Splat = Builder.CreateVectorSplat(BlockSize, LH, "splat");
+ Sum = createMulAdd(isSumZero && K == 0 ? nullptr : Sum, Splat, R,
+ IsFP, Builder, AllowContraction, NumComputeOps);
+ }
+ Result.setVector(I,
+ insertVector(Result.getVector(I), J, Sum, Builder));
+ }
+ }
+ }
+ Result.addNumComputeOps(NumComputeOps);
+ }
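+
+ // Rough sketch of the column-major path above for R = C = M = 4 with VF = 4:
+ // each result column J is built as
+ //   Result(:, J) = sum over K of A(:, K) * splat(B[K, J])
+ // i.e. one vector multiply-add per K, which keeps the additions in vector
+ // form without requiring reassociation.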
+
+ /// Ensure that the memory in \p Load does not alias \p Store by potentially
+ /// copying it to a new location. The new location, or otherwise the original
+ /// one, is returned.
+ Value *getNonAliasingPointer(LoadInst *Load, StoreInst *Store,
+ CallInst *MatMul) {
+ MemoryLocation StoreLoc = MemoryLocation::get(Store);
+ MemoryLocation LoadLoc = MemoryLocation::get(Load);
+
AliasResult LdAliased = AA->alias(LoadLoc, StoreLoc);
-
- // If we can statically determine noalias we're good.
- if (!LdAliased)
- return Load->getPointerOperand();
-
- // Create code to check if the memory locations of the Load and Store
- // overlap and if they do, copy Load's operand to a new buffer.
-
- // First, create new blocks for the second part of the check and the copy.
- BasicBlock *Check0 = MatMul->getParent();
- // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
- // DT. Manually collect dominator tree updates, to avoid unnecessary work,
- // as we adjust Check0 and Check1's branches.
- SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
- for (BasicBlock *Succ : successors(Check0))
+
+ // If we can statically determine noalias we're good.
+ if (!LdAliased)
+ return Load->getPointerOperand();
+
+ // Create code to check if the memory locations of the Load and Store
+ // overlap and if they do, copy Load's operand to a new buffer.
+
+ // First, create new blocks for the second part of the check and the copy.
+ BasicBlock *Check0 = MatMul->getParent();
+ // FIXME: Use lazy DTU and update SplitBlock to accept a DTU instead of a
+ // DT. Manually collect dominator tree updates, to avoid unnecessary work,
+ // as we adjust Check0 and Check1's branches.
+ SmallVector<DominatorTree::UpdateType, 4> DTUpdates;
+ for (BasicBlock *Succ : successors(Check0))
DTUpdates.push_back({DT->Delete, Check0, Succ});
-
+
BasicBlock *Check1 =
SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
nullptr, "alias_cont");
- BasicBlock *Copy =
+ BasicBlock *Copy =
SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
nullptr, "copy");
BasicBlock *Fusion =
SplitBlock(MatMul->getParent(), MatMul, (DomTreeUpdater *)nullptr, LI,
nullptr, "no_alias");
-
- // Check if the loaded memory location begins before the end of the store
- // location. If the condition holds, they might overlap, otherwise they are
- // guaranteed to not overlap.
- IRBuilder<> Builder(MatMul);
- Check0->getTerminator()->eraseFromParent();
- Builder.SetInsertPoint(Check0);
- Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
- Value *StoreBegin = Builder.CreatePtrToInt(
- const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
- Value *StoreEnd = Builder.CreateAdd(
- StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
- "store.end", true, true);
- Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
- IntPtrTy, "load.begin");
- Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
- Fusion);
-
- // Check if the store begins before the end of the load location. If the
- // condition holds, they alias, otherwise they are guaranteed to not
- // overlap.
- Check1->getTerminator()->eraseFromParent();
- Builder.SetInsertPoint(Check1, Check1->begin());
- Value *LoadEnd = Builder.CreateAdd(
- LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
- "load.end", true, true);
- Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
- Fusion);
-
- // Copy load operand to new alloca.
- Builder.SetInsertPoint(Copy, Copy->begin());
- AllocaInst *NewLd =
- Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace());
- Builder.CreateMemCpy(NewLd, NewLd->getAlign(),
- Load->getPointerOperand(), Load->getAlign(),
- LoadLoc.Size.getValue());
- Builder.SetInsertPoint(Fusion, Fusion->begin());
- PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
- PHI->addIncoming(Load->getPointerOperand(), Check0);
- PHI->addIncoming(Load->getPointerOperand(), Check1);
- PHI->addIncoming(NewLd, Copy);
-
- // Adjust DT.
+
+ // Check if the loaded memory location begins before the end of the store
+ // location. If the condition holds, they might overlap, otherwise they are
+ // guaranteed to not overlap.
+ IRBuilder<> Builder(MatMul);
+ Check0->getTerminator()->eraseFromParent();
+ Builder.SetInsertPoint(Check0);
+ Type *IntPtrTy = Builder.getIntPtrTy(Load->getModule()->getDataLayout());
+ Value *StoreBegin = Builder.CreatePtrToInt(
+ const_cast<Value *>(StoreLoc.Ptr), IntPtrTy, "store.begin");
+ Value *StoreEnd = Builder.CreateAdd(
+ StoreBegin, ConstantInt::get(IntPtrTy, StoreLoc.Size.getValue()),
+ "store.end", true, true);
+ Value *LoadBegin = Builder.CreatePtrToInt(const_cast<Value *>(LoadLoc.Ptr),
+ IntPtrTy, "load.begin");
+ Builder.CreateCondBr(Builder.CreateICmpULT(LoadBegin, StoreEnd), Check1,
+ Fusion);
+
+ // Check if the store begins before the end of the load location. If the
+ // condition holds, they alias, otherwise they are guaranteed to not
+ // overlap.
+ Check1->getTerminator()->eraseFromParent();
+ Builder.SetInsertPoint(Check1, Check1->begin());
+ Value *LoadEnd = Builder.CreateAdd(
+ LoadBegin, ConstantInt::get(IntPtrTy, LoadLoc.Size.getValue()),
+ "load.end", true, true);
+ Builder.CreateCondBr(Builder.CreateICmpULT(StoreBegin, LoadEnd), Copy,
+ Fusion);
+
+ // Copy load operand to new alloca.
+ Builder.SetInsertPoint(Copy, Copy->begin());
+ AllocaInst *NewLd =
+ Builder.CreateAlloca(Load->getType(), Load->getPointerAddressSpace());
+ Builder.CreateMemCpy(NewLd, NewLd->getAlign(),
+ Load->getPointerOperand(), Load->getAlign(),
+ LoadLoc.Size.getValue());
+ Builder.SetInsertPoint(Fusion, Fusion->begin());
+ PHINode *PHI = Builder.CreatePHI(Load->getPointerOperandType(), 3);
+ PHI->addIncoming(Load->getPointerOperand(), Check0);
+ PHI->addIncoming(Load->getPointerOperand(), Check1);
+ PHI->addIncoming(NewLd, Copy);
+
+ // Adjust DT.
DTUpdates.push_back({DT->Insert, Check0, Check1});
DTUpdates.push_back({DT->Insert, Check0, Fusion});
DTUpdates.push_back({DT->Insert, Check1, Copy});
DTUpdates.push_back({DT->Insert, Check1, Fusion});
DT->applyUpdates(DTUpdates);
- return PHI;
- }
-
- bool isFusionProfitable(CallInst *MatMul) {
- if (ForceFusion)
- return true;
-
- ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
- ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
-
- const unsigned R = LShape.NumRows;
- const unsigned C = RShape.NumColumns;
- const unsigned M = LShape.NumColumns;
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
-
- const unsigned VF =
- std::max<unsigned>(TTI.getRegisterBitWidth(true) /
- EltType->getPrimitiveSizeInBits().getFixedSize(),
- 1U);
-
- // Cost model for tiling
- //
- // For tiling to be beneficial, we need reuse either along the R or
- // the C axis. We vectorize along the R axis so that means at least
- // 3 elements.
- // TODO: Also consider cost of copying if operands alias.
- if (R <= VF && C == 1)
- return false;
- // Then we need enough elements to exceed the number of vector
- // registers we have. Note that this is an oversimplification since
- // fusing also takes some extra loads which may exceed the number of
- // reloads necessary.
- unsigned Op0Regs = (R + VF - 1) / VF * M;
- unsigned Op1Regs = (M + VF - 1) / VF * C;
- return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true);
- }
-
- MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
- MatrixTy Res;
- auto *ColumType = FixedVectorType::get(EltType, R);
- for (unsigned I = 0; I < C; ++I)
- Res.addVector(ConstantAggregateZero::get(ColumType));
- return Res;
- }
-
+ return PHI;
+ }
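+
+ // Sketch of the control flow built above (block names from the SplitBlock
+ // calls): the original block ends with the first bounds check and branches
+ // to "alias_cont" or "no_alias"; "alias_cont" performs the second check and
+ // branches to "copy" or "no_alias"; "copy" memcpys the load operand into a
+ // fresh alloca. The PHI in "no_alias" then selects the original pointer
+ // (from either check) or the new alloca (from "copy").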
+
+ bool isFusionProfitable(CallInst *MatMul) {
+ if (ForceFusion)
+ return true;
+
+ ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
+ ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
+
+ const unsigned R = LShape.NumRows;
+ const unsigned C = RShape.NumColumns;
+ const unsigned M = LShape.NumColumns;
+ auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+
+ const unsigned VF =
+ std::max<unsigned>(TTI.getRegisterBitWidth(true) /
+ EltType->getPrimitiveSizeInBits().getFixedSize(),
+ 1U);
+
+ // Cost model for tiling
+ //
+ // For tiling to be beneficial, we need reuse either along the R or
+ // the C axis. We vectorize along the R axis so that means at least
+ // 3 elements.
+ // TODO: Also consider cost of copying if operands alias.
+ if (R <= VF && C == 1)
+ return false;
+ // Then we need enough elements to exceed the number of vector
+ // registers we have. Note that this is an oversimplification since
+ // fusing also takes some extra loads which may exceed the number of
+ // reloads necessary.
+ unsigned Op0Regs = (R + VF - 1) / VF * M;
+ unsigned Op1Regs = (M + VF - 1) / VF * C;
+ return Op0Regs + Op1Regs > TTI.getNumberOfRegisters(true);
+ }
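+
+ // Example of the heuristic, assuming 128-bit vector registers and double
+ // elements (VF = 2): for R = C = M = 8, Op0Regs = ((8 + 1) / 2) * 8 = 32 and
+ // Op1Regs = 32, so fusion is considered profitable on any target with fewer
+ // than 64 vector registers.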
+
+ MatrixTy getZeroMatrix(Type *EltType, unsigned R, unsigned C) {
+ MatrixTy Res;
+ auto *ColumType = FixedVectorType::get(EltType, R);
+ for (unsigned I = 0; I < C; ++I)
+ Res.addVector(ConstantAggregateZero::get(ColumType));
+ return Res;
+ }
+
void createTiledLoops(CallInst *MatMul, Value *LPtr, ShapeInfo LShape,
Value *RPtr, ShapeInfo RShape, StoreInst *Store,
bool AllowContract) {
@@ -1266,28 +1266,28 @@ public:
"llvm.loop.unroll.count", InnerLoopUnrollCount);
}
- void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
- StoreInst *Store,
- SmallPtrSetImpl<Instruction *> &FusedInsts) {
- assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
- "Tiling only supported for column-major matrixes at the moment!");
- if (!isFusionProfitable(MatMul))
- return;
-
- ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
- ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
-
- const unsigned R = LShape.NumRows;
- const unsigned C = RShape.NumColumns;
- const unsigned M = LShape.NumColumns;
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
-
- Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
- Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
- Value *CPtr = Store->getPointerOperand();
-
- bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
- MatMul->hasAllowContract());
+ void emitSIMDTiling(CallInst *MatMul, LoadInst *LoadOp0, LoadInst *LoadOp1,
+ StoreInst *Store,
+ SmallPtrSetImpl<Instruction *> &FusedInsts) {
+ assert(MatrixLayout == MatrixLayoutTy::ColumnMajor &&
+ "Tiling only supported for column-major matrixes at the moment!");
+ if (!isFusionProfitable(MatMul))
+ return;
+
+ ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
+ ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
+
+ const unsigned R = LShape.NumRows;
+ const unsigned C = RShape.NumColumns;
+ const unsigned M = LShape.NumColumns;
+ auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+
+ Value *APtr = getNonAliasingPointer(LoadOp0, Store, MatMul);
+ Value *BPtr = getNonAliasingPointer(LoadOp1, Store, MatMul);
+ Value *CPtr = Store->getPointerOperand();
+
+ bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
+ MatMul->hasAllowContract());
if (TileUseLoops && (R % TileSize == 0 && C % TileSize == 0))
createTiledLoops(MatMul, APtr, LShape, BPtr, RShape, Store,
AllowContract);
@@ -1298,7 +1298,7 @@ public:
const unsigned TileR = std::min(R - I, unsigned(TileSize));
const unsigned TileC = std::min(C - J, unsigned(TileSize));
MatrixTy Res = getZeroMatrix(EltType, TileR, TileC);
-
+
for (unsigned K = 0; K < M; K += TileSize) {
const unsigned TileM = std::min(M - K, unsigned(TileSize));
MatrixTy A =
@@ -1314,192 +1314,192 @@ public:
storeMatrix(Res, CPtr, Store->getAlign(), Store->isVolatile(), {R, M},
Builder.getInt64(I), Builder.getInt64(J), EltType,
Builder);
- }
- }
-
- // Mark eliminated instructions as fused and remove them.
- FusedInsts.insert(Store);
- FusedInsts.insert(MatMul);
- Store->eraseFromParent();
- MatMul->eraseFromParent();
- if (LoadOp0->hasNUses(0)) {
- FusedInsts.insert(LoadOp0);
- LoadOp0->eraseFromParent();
+ }
}
- if (LoadOp1->hasNUses(0)) {
- FusedInsts.insert(LoadOp1);
- LoadOp1->eraseFromParent();
- }
- }
-
- /// Try to lower matrix multiply chains by fusing operations.
- ///
- /// Currently we only lower {ld, ld} -> matmul -> st chains.
- ///
- /// No need to return a MatrixTy object for the result of the operation, since
- /// the single store user will be lowered as part of this. Instructions that
- /// are completely eliminated by fusion are added to \p FusedInsts.
- void LowerMatrixMultiplyFused(CallInst *MatMul,
- SmallPtrSetImpl<Instruction *> &FusedInsts) {
- if (!FuseMatrix || !MatMul->hasOneUse() ||
+
+ // Mark eliminated instructions as fused and remove them.
+ FusedInsts.insert(Store);
+ FusedInsts.insert(MatMul);
+ Store->eraseFromParent();
+ MatMul->eraseFromParent();
+ if (LoadOp0->hasNUses(0)) {
+ FusedInsts.insert(LoadOp0);
+ LoadOp0->eraseFromParent();
+ }
+ if (LoadOp1->hasNUses(0)) {
+ FusedInsts.insert(LoadOp1);
+ LoadOp1->eraseFromParent();
+ }
+ }
+
+ /// Try to lower matrix multiply chains by fusing operations.
+ ///
+ /// Currently we only lower {ld, ld} -> matmul -> st chains.
+ ///
+ /// No need to return a MatrixTy object for the result of the operation, since
+ /// the single store user will be lowered as part of this. Instructions that
+ /// are completely eliminated by fusion are added to \p FusedInsts.
+ void LowerMatrixMultiplyFused(CallInst *MatMul,
+ SmallPtrSetImpl<Instruction *> &FusedInsts) {
+ if (!FuseMatrix || !MatMul->hasOneUse() ||
MatrixLayout != MatrixLayoutTy::ColumnMajor || !DT)
- return;
-
+ return;
+
assert(AA && LI && "Analyses should be available");
- auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0));
- auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1));
- auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
- if (LoadOp0 && LoadOp1 && Store) {
- // The store address must dominate the MatMul instruction, otherwise
- // we create invalid IR.
- // FIXME: See if we can hoist the store address computation.
- auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1));
+ auto *LoadOp0 = dyn_cast<LoadInst>(MatMul->getOperand(0));
+ auto *LoadOp1 = dyn_cast<LoadInst>(MatMul->getOperand(1));
+ auto *Store = dyn_cast<StoreInst>(*MatMul->user_begin());
+ if (LoadOp0 && LoadOp1 && Store) {
+ // The store address must dominate the MatMul instruction, otherwise
+ // we create invalid IR.
+ // FIXME: See if we can hoist the store address computation.
+ auto *AddrI = dyn_cast<Instruction>(Store->getOperand(1));
if (AddrI && (!DT->dominates(AddrI, MatMul)))
- return;
-
- emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
- return;
- }
- }
-
- /// Lowers llvm.matrix.multiply.
- void LowerMultiply(CallInst *MatMul) {
- IRBuilder<> Builder(MatMul);
- auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
- ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
- ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
-
- const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
- const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
+ return;
+
+ emitSIMDTiling(MatMul, LoadOp0, LoadOp1, Store, FusedInsts);
+ return;
+ }
+ }
+
+ /// Lowers llvm.matrix.multiply.
+ void LowerMultiply(CallInst *MatMul) {
+ IRBuilder<> Builder(MatMul);
+ auto *EltType = cast<VectorType>(MatMul->getType())->getElementType();
+ ShapeInfo LShape(MatMul->getArgOperand(2), MatMul->getArgOperand(3));
+ ShapeInfo RShape(MatMul->getArgOperand(3), MatMul->getArgOperand(4));
+
+ const MatrixTy &Lhs = getMatrix(MatMul->getArgOperand(0), LShape, Builder);
+ const MatrixTy &Rhs = getMatrix(MatMul->getArgOperand(1), RShape, Builder);
assert(Lhs.getElementType() == Rhs.getElementType() &&
"Matrix multiply argument element types do not match.");
-
- const unsigned R = LShape.NumRows;
- const unsigned C = RShape.NumColumns;
- assert(LShape.NumColumns == RShape.NumRows);
-
- // Initialize the output
- MatrixTy Result(R, C, EltType);
+
+ const unsigned R = LShape.NumRows;
+ const unsigned C = RShape.NumColumns;
+ assert(LShape.NumColumns == RShape.NumRows);
+
+ // Initialize the output
+ MatrixTy Result(R, C, EltType);
assert(Lhs.getElementType() == Result.getElementType() &&
"Matrix multiply result element type does not match arguments.");
-
- bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
- MatMul->hasAllowContract());
- emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false);
- finalizeLowering(MatMul, Result, Builder);
- }
-
- /// Lowers llvm.matrix.transpose.
- void LowerTranspose(CallInst *Inst) {
- MatrixTy Result;
- IRBuilder<> Builder(Inst);
- Value *InputVal = Inst->getArgOperand(0);
- VectorType *VectorTy = cast<VectorType>(InputVal->getType());
- ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
- MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
-
- const unsigned NewNumVecs =
- InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
- const unsigned NewNumElts =
- InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
-
- for (unsigned I = 0; I < NewNumVecs; ++I) {
- // Build a single result vector. First initialize it.
- Value *ResultVector = UndefValue::get(
- FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
- // Go through the old elements and insert them into the resulting vector.
- for (auto J : enumerate(InputMatrix.vectors())) {
- Value *Elt = Builder.CreateExtractElement(J.value(), I);
- // Row and column indices are transposed.
- ResultVector =
- Builder.CreateInsertElement(ResultVector, Elt, J.index());
- }
- Result.addVector(ResultVector);
- }
-
- // TODO: Improve estimate of operations needed for transposes. Currently we
- // just count the insertelement/extractelement instructions, but do not
- // account for later simplifications/combines.
- finalizeLowering(
- Inst,
- Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
- Builder);
- }
-
- /// Lower load instructions, if shape information is available.
- bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
- auto I = ShapeMap.find(Inst);
- if (I == ShapeMap.end())
- return false;
-
- LowerLoad(Inst, Ptr, Inst->getAlign(),
- Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
- I->second);
- return true;
- }
-
- bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
- IRBuilder<> &Builder) {
- auto I = ShapeMap.find(StoredVal);
- if (I == ShapeMap.end())
- return false;
-
- LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
- Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
- I->second);
- return true;
- }
-
- /// Lower binary operators, if shape information is available.
- bool VisitBinaryOperator(BinaryOperator *Inst) {
- auto I = ShapeMap.find(Inst);
- if (I == ShapeMap.end())
- return false;
-
- Value *Lhs = Inst->getOperand(0);
- Value *Rhs = Inst->getOperand(1);
-
- IRBuilder<> Builder(Inst);
- ShapeInfo &Shape = I->second;
-
- MatrixTy Result;
- MatrixTy A = getMatrix(Lhs, Shape, Builder);
- MatrixTy B = getMatrix(Rhs, Shape, Builder);
- assert(A.isColumnMajor() == B.isColumnMajor() &&
- Result.isColumnMajor() == A.isColumnMajor() &&
- "operands must agree on matrix layout");
-
- // Helper to perform binary op on vectors.
- auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
- switch (Inst->getOpcode()) {
- case Instruction::Add:
- return Builder.CreateAdd(LHS, RHS);
- case Instruction::Mul:
- return Builder.CreateMul(LHS, RHS);
- case Instruction::Sub:
- return Builder.CreateSub(LHS, RHS);
- case Instruction::FAdd:
- return Builder.CreateFAdd(LHS, RHS);
- case Instruction::FMul:
- return Builder.CreateFMul(LHS, RHS);
- case Instruction::FSub:
- return Builder.CreateFSub(LHS, RHS);
- default:
- llvm_unreachable("Unsupported binary operator for matrix");
- }
- };
-
- for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
- Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
-
- finalizeLowering(Inst,
- Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
- Result.getNumVectors()),
- Builder);
- return true;
- }
-
+
+ bool AllowContract = AllowContractEnabled || (isa<FPMathOperator>(MatMul) &&
+ MatMul->hasAllowContract());
+ emitMatrixMultiply(Result, Lhs, Rhs, AllowContract, Builder, false);
+ finalizeLowering(MatMul, Result, Builder);
+ }
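+
+ // For reference, the operand order assumed above is
+ //   llvm.matrix.multiply(%Lhs, %Rhs, i32 %LhsRows, i32 %LhsCols, i32 %RhsCols)
+ // so LShape is (operand 2 x operand 3) and RShape is (operand 3 x operand 4).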
+
+ /// Lowers llvm.matrix.transpose.
+ void LowerTranspose(CallInst *Inst) {
+ MatrixTy Result;
+ IRBuilder<> Builder(Inst);
+ Value *InputVal = Inst->getArgOperand(0);
+ VectorType *VectorTy = cast<VectorType>(InputVal->getType());
+ ShapeInfo ArgShape(Inst->getArgOperand(1), Inst->getArgOperand(2));
+ MatrixTy InputMatrix = getMatrix(InputVal, ArgShape, Builder);
+
+ const unsigned NewNumVecs =
+ InputMatrix.isColumnMajor() ? ArgShape.NumRows : ArgShape.NumColumns;
+ const unsigned NewNumElts =
+ InputMatrix.isColumnMajor() ? ArgShape.NumColumns : ArgShape.NumRows;
+
+ for (unsigned I = 0; I < NewNumVecs; ++I) {
+ // Build a single result vector. First initialize it.
+ Value *ResultVector = UndefValue::get(
+ FixedVectorType::get(VectorTy->getElementType(), NewNumElts));
+ // Go through the old elements and insert them into the resulting vector.
+ for (auto J : enumerate(InputMatrix.vectors())) {
+ Value *Elt = Builder.CreateExtractElement(J.value(), I);
+ // Row and column indices are transposed.
+ ResultVector =
+ Builder.CreateInsertElement(ResultVector, Elt, J.index());
+ }
+ Result.addVector(ResultVector);
+ }
+
+ // TODO: Improve estimate of operations needed for transposes. Currently we
+ // just count the insertelement/extractelement instructions, but do not
+ // account for later simplifications/combines.
+ finalizeLowering(
+ Inst,
+ Result.addNumComputeOps(2 * ArgShape.NumRows * ArgShape.NumColumns),
+ Builder);
+ }
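+
+ // Sketch for a 2 x 3 column-major transpose: the result consists of
+ // NewNumVecs = 2 vectors of NewNumElts = 3 elements, built from 6
+ // extractelement/insertelement pairs, and the compute-op estimate above is
+ // 2 * 2 * 3 = 12.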
+
+ /// Lower load instructions, if shape information is available.
+ bool VisitLoad(LoadInst *Inst, Value *Ptr, IRBuilder<> &Builder) {
+ auto I = ShapeMap.find(Inst);
+ if (I == ShapeMap.end())
+ return false;
+
+ LowerLoad(Inst, Ptr, Inst->getAlign(),
+ Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
+ I->second);
+ return true;
+ }
+
+ bool VisitStore(StoreInst *Inst, Value *StoredVal, Value *Ptr,
+ IRBuilder<> &Builder) {
+ auto I = ShapeMap.find(StoredVal);
+ if (I == ShapeMap.end())
+ return false;
+
+ LowerStore(Inst, StoredVal, Ptr, Inst->getAlign(),
+ Builder.getInt64(I->second.getStride()), Inst->isVolatile(),
+ I->second);
+ return true;
+ }
+
+ /// Lower binary operators, if shape information is available.
+ bool VisitBinaryOperator(BinaryOperator *Inst) {
+ auto I = ShapeMap.find(Inst);
+ if (I == ShapeMap.end())
+ return false;
+
+ Value *Lhs = Inst->getOperand(0);
+ Value *Rhs = Inst->getOperand(1);
+
+ IRBuilder<> Builder(Inst);
+ ShapeInfo &Shape = I->second;
+
+ MatrixTy Result;
+ MatrixTy A = getMatrix(Lhs, Shape, Builder);
+ MatrixTy B = getMatrix(Rhs, Shape, Builder);
+ assert(A.isColumnMajor() == B.isColumnMajor() &&
+ Result.isColumnMajor() == A.isColumnMajor() &&
+ "operands must agree on matrix layout");
+
+ // Helper to perform binary op on vectors.
+ auto BuildVectorOp = [&Builder, Inst](Value *LHS, Value *RHS) {
+ switch (Inst->getOpcode()) {
+ case Instruction::Add:
+ return Builder.CreateAdd(LHS, RHS);
+ case Instruction::Mul:
+ return Builder.CreateMul(LHS, RHS);
+ case Instruction::Sub:
+ return Builder.CreateSub(LHS, RHS);
+ case Instruction::FAdd:
+ return Builder.CreateFAdd(LHS, RHS);
+ case Instruction::FMul:
+ return Builder.CreateFMul(LHS, RHS);
+ case Instruction::FSub:
+ return Builder.CreateFSub(LHS, RHS);
+ default:
+ llvm_unreachable("Unsupported binary operator for matrix");
+ }
+ };
+
+ for (unsigned I = 0; I < Shape.getNumVectors(); ++I)
+ Result.addVector(BuildVectorOp(A.getVector(I), B.getVector(I)));
+
+ finalizeLowering(Inst,
+ Result.addNumComputeOps(getNumOps(Result.getVectorTy()) *
+ Result.getNumVectors()),
+ Builder);
+ return true;
+ }
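+
+ // Sketch: an fadd of two values with a known 2 x 2 float shape is lowered to
+ // two <2 x float> fadd instructions, one per column vector, and the result
+ // is cached via finalizeLowering just like the intrinsic lowerings above.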
+
/// Lower unary operators, if shape information is available.
bool VisitUnaryOperator(UnaryOperator *Inst) {
auto I = ShapeMap.find(Inst);
@@ -1534,449 +1534,449 @@ public:
return true;
}
- /// Helper to linearize a matrix expression tree into a string. Currently
- /// matrix expressions are linearized by starting at an expression leaf and
- /// linearizing bottom up.
- struct ExprLinearizer {
- unsigned LengthToBreak = 100;
- std::string Str;
- raw_string_ostream Stream;
- unsigned LineLength = 0;
- const DataLayout &DL;
-
- /// Mapping from instructions to matrixes. It is used to identify
- /// matrix instructions.
- const MapVector<Value *, MatrixTy> &Inst2Matrix;
-
- /// Mapping from values to the leaves of all expressions that the value is
- /// part of.
- const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
-
- /// Set of matrix expressions in the scope of a given DISubprogram.
- const SmallSetVector<Value *, 32> &ExprsInSubprogram;
-
- /// Leaf node of the expression to linearize.
- Value *Leaf;
-
- /// Used to keep track of sub-expressions that get reused while linearizing
- /// the expression. Re-used sub-expressions are marked as (reused).
- SmallPtrSet<Value *, 8> ReusedExprs;
-
- ExprLinearizer(const DataLayout &DL,
- const MapVector<Value *, MatrixTy> &Inst2Matrix,
- const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
- const SmallSetVector<Value *, 32> &ExprsInSubprogram,
- Value *Leaf)
- : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
- ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
-
- void indent(unsigned N) {
- LineLength += N;
- for (unsigned i = 0; i < N; i++)
- Stream << " ";
- }
-
- void lineBreak() {
- Stream << "\n";
- LineLength = 0;
- }
-
- void maybeIndent(unsigned Indent) {
- if (LineLength >= LengthToBreak)
- lineBreak();
-
- if (LineLength == 0)
- indent(Indent);
- }
-
- void write(StringRef S) {
- LineLength += S.size();
- Stream << S;
- }
-
- Value *getUnderlyingObjectThroughLoads(Value *V) {
- if (Value *Ptr = getPointerOperand(V))
- return getUnderlyingObjectThroughLoads(Ptr);
- else if (V->getType()->isPointerTy())
+ /// Helper to linearize a matrix expression tree into a string. Currently
+ /// matrix expressions are linearized by starting at an expression leaf and
+ /// linearizing bottom up.
+ struct ExprLinearizer {
+ unsigned LengthToBreak = 100;
+ std::string Str;
+ raw_string_ostream Stream;
+ unsigned LineLength = 0;
+ const DataLayout &DL;
+
+ /// Mapping from instructions to matrixes. It is used to identify
+ /// matrix instructions.
+ const MapVector<Value *, MatrixTy> &Inst2Matrix;
+
+ /// Mapping from values to the leaves of all expressions that the value is
+ /// part of.
+ const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared;
+
+ /// Set of matrix expressions in the scope of a given DISubprogram.
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram;
+
+ /// Leaf node of the expression to linearize.
+ Value *Leaf;
+
+ /// Used to keep track of sub-expressions that get reused while linearizing
+ /// the expression. Re-used sub-expressions are marked as (reused).
+ SmallPtrSet<Value *, 8> ReusedExprs;
+
+ ExprLinearizer(const DataLayout &DL,
+ const MapVector<Value *, MatrixTy> &Inst2Matrix,
+ const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+ Value *Leaf)
+ : Str(), Stream(Str), DL(DL), Inst2Matrix(Inst2Matrix), Shared(Shared),
+ ExprsInSubprogram(ExprsInSubprogram), Leaf(Leaf) {}
+
+ void indent(unsigned N) {
+ LineLength += N;
+ for (unsigned i = 0; i < N; i++)
+ Stream << " ";
+ }
+
+ void lineBreak() {
+ Stream << "\n";
+ LineLength = 0;
+ }
+
+ void maybeIndent(unsigned Indent) {
+ if (LineLength >= LengthToBreak)
+ lineBreak();
+
+ if (LineLength == 0)
+ indent(Indent);
+ }
+
+ void write(StringRef S) {
+ LineLength += S.size();
+ Stream << S;
+ }
+
+ Value *getUnderlyingObjectThroughLoads(Value *V) {
+ if (Value *Ptr = getPointerOperand(V))
+ return getUnderlyingObjectThroughLoads(Ptr);
+ else if (V->getType()->isPointerTy())
return getUnderlyingObject(V);
- return V;
- }
-
- /// Returns true if \p V is a matrix value in the given subprogram.
- bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
-
- /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
- /// \p SS.
- void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
- auto M = Inst2Matrix.find(V);
- if (M == Inst2Matrix.end())
- SS << "unknown";
- else {
- SS << M->second.getNumRows();
- SS << "x";
- SS << M->second.getNumColumns();
- }
- }
-
- /// Write the called function name. Handles calls to llvm.matrix.*
- /// specially: we write the name, followed by the dimensions of the input
- /// matrixes, followed by the scalar type name.
- void writeFnName(CallInst *CI) {
- if (!CI->getCalledFunction())
- write("<no called fn>");
- else {
- StringRef Name = CI->getCalledFunction()->getName();
- if (!Name.startswith("llvm.matrix")) {
- write(Name);
- return;
- }
- IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
- write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
- .drop_front(StringRef("llvm.matrix.").size()));
- write(".");
+ return V;
+ }
+
+ /// Returns true if \p V is a matrix value in the given subprogram.
+ bool isMatrix(Value *V) const { return ExprsInSubprogram.count(V); }
+
+ /// If \p V is a matrix value, print its shape as NumRows x NumColumns to
+ /// \p SS.
+ void prettyPrintMatrixType(Value *V, raw_string_ostream &SS) {
+ auto M = Inst2Matrix.find(V);
+ if (M == Inst2Matrix.end())
+ SS << "unknown";
+ else {
+ SS << M->second.getNumRows();
+ SS << "x";
+ SS << M->second.getNumColumns();
+ }
+ }
+
+ /// Write the called function name. Handles calls to llvm.matrix.*
+ /// specially: we write the name, followed by the dimensions of the input
+ /// matrixes, followed by the scalar type name.
+ void writeFnName(CallInst *CI) {
+ if (!CI->getCalledFunction())
+ write("<no called fn>");
+ else {
+ StringRef Name = CI->getCalledFunction()->getName();
+ if (!Name.startswith("llvm.matrix")) {
+ write(Name);
+ return;
+ }
+ IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
+ write(StringRef(Intrinsic::getName(II->getIntrinsicID(), {}))
+ .drop_front(StringRef("llvm.matrix.").size()));
+ write(".");
std::string Tmp;
- raw_string_ostream SS(Tmp);
-
- switch (II->getIntrinsicID()) {
- case Intrinsic::matrix_multiply:
- prettyPrintMatrixType(II->getOperand(0), SS);
- SS << ".";
- prettyPrintMatrixType(II->getOperand(1), SS);
- SS << "." << *II->getType()->getScalarType();
- break;
- case Intrinsic::matrix_transpose:
- prettyPrintMatrixType(II->getOperand(0), SS);
- SS << "." << *II->getType()->getScalarType();
- break;
- case Intrinsic::matrix_column_major_load:
- prettyPrintMatrixType(II, SS);
- SS << "." << *II->getType()->getScalarType();
- break;
- case Intrinsic::matrix_column_major_store:
- prettyPrintMatrixType(II->getOperand(0), SS);
- SS << "." << *II->getOperand(0)->getType()->getScalarType();
- break;
- default:
- llvm_unreachable("Unhandled case");
- }
- SS.flush();
- write(Tmp);
- }
- }
-
- unsigned getNumShapeArgs(CallInst *CI) const {
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
- switch (II->getIntrinsicID()) {
- case Intrinsic::matrix_multiply:
- return 3;
- case Intrinsic::matrix_transpose:
- return 2;
- case Intrinsic::matrix_column_major_load:
- case Intrinsic::matrix_column_major_store:
- return 3;
- default:
- return 0;
- }
- }
- return 0;
- }
-
- /// Special printing for values: for pointers, we print whether they refer to
- /// a (function-)external address or a stack address; for other values we
- /// either print the constant or "scalar"/"matrix".
- void write(Value *V) {
- V = getUnderlyingObjectThroughLoads(V);
- if (V->getType()->isPointerTy()) {
- if (isa<AllocaInst>(V)) {
- Stream << "stack addr";
- LineLength += StringRef("stack addr").size();
- } else {
- Stream << "addr";
- LineLength += StringRef("addr").size();
- }
- if (!V->getName().empty()) {
- Stream << " %" << V->getName() << "";
- LineLength += V->getName().size() + 2;
- }
- return;
- }
-
- std::string Tmp;
- raw_string_ostream TmpStream(Tmp);
-
- if (auto *CI = dyn_cast<ConstantInt>(V))
- TmpStream << CI->getValue();
- else if (isa<Constant>(V))
- TmpStream << "constant";
- else {
- if (isMatrix(V))
- TmpStream << "matrix";
- else
- TmpStream << "scalar";
- }
- TmpStream.flush();
- Tmp = std::string(StringRef(Tmp).trim());
- LineLength += Tmp.size();
- Stream << Tmp;
- }
-
- /// Linearize expression \p Expr starting at an indentation of \p Indent.
- /// Expressions that are re-used multiple times are prefixed with (reused)
- /// at the re-used root instruction.
- void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
- bool ParentShared) {
- auto *I = cast<Instruction>(Expr);
- maybeIndent(Indent);
- SmallVector<Value *, 8> Ops;
-
- // Is Expr shared with other expression leaves?
- bool ExprShared = false;
-
- // Deal with shared subtrees. Mark them as shared, if required.
- if (!ParentShared) {
- auto SI = Shared.find(Expr);
- assert(SI != Shared.end() && SI->second.count(Leaf));
-
- for (Value *S : SI->second) {
- if (S == Leaf)
- continue;
- DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
- write("shared with remark at line " + std::to_string(DL.getLine()) +
- " column " + std::to_string(DL.getCol()) + " (");
- }
- ExprShared = SI->second.size() > 1;
- }
-
- bool Reused = !ReusedExprs.insert(Expr).second;
- if (Reused && !ParentReused)
- write("(reused) ");
-
- if (auto *CI = dyn_cast<CallInst>(I)) {
- writeFnName(CI);
-
- Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
- } else if (isa<BitCastInst>(Expr)) {
- // Special case bitcasts, which are used to materialize matrixes from
- // non-matrix ops.
- write("matrix");
- return;
- } else {
- Ops.append(I->value_op_begin(), I->value_op_end());
- write(std::string(I->getOpcodeName()));
- }
-
- write(std::string("("));
-
- unsigned NumOpsToBreak = 1;
- if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
- NumOpsToBreak = 2;
-
- for (Value *Op : Ops) {
- if (Ops.size() > NumOpsToBreak)
- lineBreak();
-
- maybeIndent(Indent + 1);
- if (isMatrix(Op))
- linearizeExpr(Op, Indent + 1, Reused, ExprShared);
- else
- write(Op);
- if (Op != Ops.back())
- write(", ");
- }
-
- write(")");
- }
-
- const std::string &getResult() {
- Stream.flush();
- return Str;
- }
- };
-
- /// Generate remarks for matrix operations in a function. To generate remarks
- /// for matrix expressions, the following approach is used:
- /// 1. Use the inlined-at debug information to group matrix operations to the
- /// DISubprograms they are contained in.
- /// 2. Collect leaves of matrix expressions (done in
- /// RemarkGenerator::getExpressionLeaves) for each subprogram - expression
- /// mapping. Leaves are lowered matrix instructions without other matrix
- /// users (like stores) in the current subprogram.
- /// 3. For each leaf, create a remark containing a linearized version of the
- /// matrix expression. The expression is linearized by a recursive
- /// bottom-up traversal of the matrix operands, starting at a leaf. Note
- /// that multiple leaves can share sub-expressions. Shared subexpressions
- /// are explicitly marked as shared().
- struct RemarkGenerator {
- const MapVector<Value *, MatrixTy> &Inst2Matrix;
- OptimizationRemarkEmitter &ORE;
- Function &Func;
- const DataLayout &DL;
-
- RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
- OptimizationRemarkEmitter &ORE, Function &Func)
- : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
- DL(Func.getParent()->getDataLayout()) {}
-
- /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
- /// instructions in Inst2Matrix returning void or without any users in
- /// \p ExprsInSubprogram. Currently that should only include stores.
- SmallVector<Value *, 4>
- getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
- SmallVector<Value *, 4> Leaves;
- for (auto *Expr : ExprsInSubprogram)
- if (Expr->getType()->isVoidTy() ||
- !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
- return ExprsInSubprogram.count(U);
- }))
- Leaves.push_back(Expr);
- return Leaves;
- }
-
- /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
- /// to all visited expressions in \p Shared. Limit the matrix operations to
- /// the ones in \p ExprsInSubprogram.
- void collectSharedInfo(Value *Leaf, Value *V,
- const SmallSetVector<Value *, 32> &ExprsInSubprogram,
- DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
-
- if (!ExprsInSubprogram.count(V))
- return;
-
- auto I = Shared.insert({V, {}});
- I.first->second.insert(Leaf);
-
- for (Value *Op : cast<Instruction>(V)->operand_values())
- collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
- }
-
-  /// Calculate the exclusive and shared op counts for the expression
-  /// starting at \p Root. Expressions used multiple times are counted once.
- /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
- std::pair<OpInfoTy, OpInfoTy>
- sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
- const SmallSetVector<Value *, 32> &ExprsInSubprogram,
- DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
- if (!ExprsInSubprogram.count(Root))
- return {};
-
- // Already counted this expression. Stop.
- if (!ReusedExprs.insert(Root).second)
- return {};
-
- OpInfoTy SharedCount;
- OpInfoTy Count;
-
- auto I = Shared.find(Root);
- auto CM = Inst2Matrix.find(Root);
- if (I->second.size() == 1)
- Count = CM->second.getOpInfo();
- else
- SharedCount = CM->second.getOpInfo();
-
- for (Value *Op : cast<Instruction>(Root)->operand_values()) {
- auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
- Count += C.first;
- SharedCount += C.second;
- }
- return {Count, SharedCount};
- }
-
- void emitRemarks() {
- if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
- return;
-
-    // Map matrix operations to their containing subprograms by traversing
- // the inlinedAt chain. If the function does not have a DISubprogram, we
- // only map them to the containing function.
- MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
- for (auto &KV : Inst2Matrix) {
- if (Func.getSubprogram()) {
- auto *I = cast<Instruction>(KV.first);
- DILocation *Context = I->getDebugLoc();
- while (Context) {
- auto I =
- Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
- I.first->second.push_back(KV.first);
- Context = DebugLoc(Context).getInlinedAt();
- }
- } else {
- auto I = Subprog2Exprs.insert({nullptr, {}});
- I.first->second.push_back(KV.first);
- }
- }
- for (auto &KV : Subprog2Exprs) {
- SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
- KV.second.end());
- auto Leaves = getExpressionLeaves(ExprsInSubprogram);
-
- DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
- for (Value *Leaf : Leaves)
- collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
-
- // Generate remarks for each leaf.
- for (auto *L : Leaves) {
-
- DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
- DILocation *Context = cast<Instruction>(L)->getDebugLoc();
- while (Context) {
- if (getSubprogram(Context->getScope()) == KV.first) {
- Loc = Context;
- break;
- }
- Context = DebugLoc(Context).getInlinedAt();
- }
-
- SmallPtrSet<Value *, 8> ReusedExprs;
- OpInfoTy Counts, SharedCounts;
- std::tie(Counts, SharedCounts) =
- sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
-
- OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
- cast<Instruction>(L)->getParent());
-
- Rem << "Lowered with ";
- Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
- << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
- << ore::NV("NumComputeOps", Counts.NumComputeOps)
- << " compute ops";
-
- if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
- SharedCounts.NumComputeOps > 0) {
- Rem << ",\nadditionally "
- << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
- << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
- << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
- << " compute ops"
- << " are shared with other expressions";
- }
-
- Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
- ORE.emit(Rem);
- }
- }
- }
-
- std::string
- linearize(Value *L,
- const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
- const SmallSetVector<Value *, 32> &ExprsInSubprogram,
- const DataLayout &DL) {
- ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
- Lin.linearizeExpr(L, 0, false, false);
- return Lin.getResult();
- }
- };
-};
-} // namespace
-
-PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
- FunctionAnalysisManager &AM) {
- auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+ raw_string_ostream SS(Tmp);
+
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::matrix_multiply:
+ prettyPrintMatrixType(II->getOperand(0), SS);
+ SS << ".";
+ prettyPrintMatrixType(II->getOperand(1), SS);
+ SS << "." << *II->getType()->getScalarType();
+ break;
+ case Intrinsic::matrix_transpose:
+ prettyPrintMatrixType(II->getOperand(0), SS);
+ SS << "." << *II->getType()->getScalarType();
+ break;
+ case Intrinsic::matrix_column_major_load:
+ prettyPrintMatrixType(II, SS);
+ SS << "." << *II->getType()->getScalarType();
+ break;
+ case Intrinsic::matrix_column_major_store:
+ prettyPrintMatrixType(II->getOperand(0), SS);
+ SS << "." << *II->getOperand(0)->getType()->getScalarType();
+ break;
+ default:
+ llvm_unreachable("Unhandled case");
+ }
+ SS.flush();
+ write(Tmp);
+ }
+ }
+
+ unsigned getNumShapeArgs(CallInst *CI) const {
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::matrix_multiply:
+ return 3;
+ case Intrinsic::matrix_transpose:
+ return 2;
+ case Intrinsic::matrix_column_major_load:
+ case Intrinsic::matrix_column_major_store:
+ return 3;
+ default:
+ return 0;
+ }
+ }
+ return 0;
+ }
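+
+  // Illustrative sketch, assuming the LangRef form of the multiply intrinsic:
+  // in a call such as
+  //   %c = call <4 x double> @llvm.matrix.multiply.v4f64.v6f64.v6f64(
+  //            <6 x double> %a, <6 x double> %b, i32 2, i32 3, i32 2)
+  // the trailing "i32 2, i32 3, i32 2" are the shape arguments counted here;
+  // linearizeExpr() strips them from the printed operand list.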
+
+  /// Special printing for values: for pointers, we print whether they refer
+  /// to an address external to the function or to a stack address; for other
+  /// values, we either print the constant or "scalar"/"matrix".
+ void write(Value *V) {
+ V = getUnderlyingObjectThroughLoads(V);
+ if (V->getType()->isPointerTy()) {
+ if (isa<AllocaInst>(V)) {
+ Stream << "stack addr";
+ LineLength += StringRef("stack addr").size();
+ } else {
+ Stream << "addr";
+ LineLength += StringRef("addr").size();
+ }
+ if (!V->getName().empty()) {
+        Stream << " %" << V->getName();
+ LineLength += V->getName().size() + 2;
+ }
+ return;
+ }
+
+ std::string Tmp;
+ raw_string_ostream TmpStream(Tmp);
+
+ if (auto *CI = dyn_cast<ConstantInt>(V))
+ TmpStream << CI->getValue();
+ else if (isa<Constant>(V))
+ TmpStream << "constant";
+ else {
+ if (isMatrix(V))
+ TmpStream << "matrix";
+ else
+ TmpStream << "scalar";
+ }
+ TmpStream.flush();
+ Tmp = std::string(StringRef(Tmp).trim());
+ LineLength += Tmp.size();
+ Stream << Tmp;
+ }
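+
+  // Illustrative summary of the cases above: an alloca-backed pointer renders
+  // as "stack addr %A", other pointers as "addr %p", a ConstantInt as its
+  // value, any other Constant as "constant", and remaining values as "matrix"
+  // or "scalar".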
+
+ /// Linearize expression \p Expr starting at an indentation of \p Indent.
+ /// Expressions that are re-used multiple times are prefixed with (reused)
+ /// at the re-used root instruction.
+ void linearizeExpr(Value *Expr, unsigned Indent, bool ParentReused,
+ bool ParentShared) {
+ auto *I = cast<Instruction>(Expr);
+ maybeIndent(Indent);
+ SmallVector<Value *, 8> Ops;
+
+ // Is Expr shared with other expression leaves?
+ bool ExprShared = false;
+
+ // Deal with shared subtrees. Mark them as shared, if required.
+ if (!ParentShared) {
+ auto SI = Shared.find(Expr);
+ assert(SI != Shared.end() && SI->second.count(Leaf));
+
+ for (Value *S : SI->second) {
+ if (S == Leaf)
+ continue;
+ DebugLoc DL = cast<Instruction>(S)->getDebugLoc();
+ write("shared with remark at line " + std::to_string(DL.getLine()) +
+ " column " + std::to_string(DL.getCol()) + " (");
+ }
+ ExprShared = SI->second.size() > 1;
+ }
+
+ bool Reused = !ReusedExprs.insert(Expr).second;
+ if (Reused && !ParentReused)
+ write("(reused) ");
+
+ if (auto *CI = dyn_cast<CallInst>(I)) {
+ writeFnName(CI);
+
+ Ops.append(CI->arg_begin(), CI->arg_end() - getNumShapeArgs(CI));
+ } else if (isa<BitCastInst>(Expr)) {
+      // Special case bitcasts, which are used to materialize matrices from
+ // non-matrix ops.
+ write("matrix");
+ return;
+ } else {
+ Ops.append(I->value_op_begin(), I->value_op_end());
+ write(std::string(I->getOpcodeName()));
+ }
+
+ write(std::string("("));
+
+ unsigned NumOpsToBreak = 1;
+ if (match(Expr, m_Intrinsic<Intrinsic::matrix_column_major_load>()))
+ NumOpsToBreak = 2;
+
+ for (Value *Op : Ops) {
+ if (Ops.size() > NumOpsToBreak)
+ lineBreak();
+
+ maybeIndent(Indent + 1);
+ if (isMatrix(Op))
+ linearizeExpr(Op, Indent + 1, Reused, ExprShared);
+ else
+ write(Op);
+ if (Op != Ops.back())
+ write(", ");
+ }
+
+ write(")");
+ }
+
+ const std::string &getResult() {
+ Stream.flush();
+ return Str;
+ }
+ };
+
+ /// Generate remarks for matrix operations in a function. To generate remarks
+ /// for matrix expressions, the following approach is used:
+  /// 1. Use the inlined-at debug information to group matrix operations by the
+  ///    DISubprograms they are contained in.
+  /// 2. Collect leaves of matrix expressions (done in
+  ///    RemarkGenerator::getExpressionLeaves) for each subprogram-to-expression
+  ///    mapping. Leaves are lowered matrix instructions without other matrix
+  ///    users (like stores) in the current subprogram.
+  /// 3. For each leaf, create a remark containing a linearized version of the
+ /// matrix expression. The expression is linearized by a recursive
+ /// bottom-up traversal of the matrix operands, starting at a leaf. Note
+ /// that multiple leaves can share sub-expressions. Shared subexpressions
+ /// are explicitly marked as shared().
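+  ///
+  /// As an illustrative sketch (not verbatim pass output), a remark for a
+  /// multiply feeding a store may read roughly:
+  ///   Lowered with <N> stores, <M> loads, <K> compute ops
+  ///   store(
+  ///    multiply.2x3.3x2.double(
+  ///     load(addr %A),
+  ///     load(addr %B)),
+  ///    addr %C)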
+ struct RemarkGenerator {
+ const MapVector<Value *, MatrixTy> &Inst2Matrix;
+ OptimizationRemarkEmitter &ORE;
+ Function &Func;
+ const DataLayout &DL;
+
+ RemarkGenerator(const MapVector<Value *, MatrixTy> &Inst2Matrix,
+ OptimizationRemarkEmitter &ORE, Function &Func)
+ : Inst2Matrix(Inst2Matrix), ORE(ORE), Func(Func),
+ DL(Func.getParent()->getDataLayout()) {}
+
+ /// Return all leaves of the expressions in \p ExprsInSubprogram. Those are
+ /// instructions in Inst2Matrix returning void or without any users in
+ /// \p ExprsInSubprogram. Currently that should only include stores.
+ SmallVector<Value *, 4>
+ getExpressionLeaves(const SmallSetVector<Value *, 32> &ExprsInSubprogram) {
+ SmallVector<Value *, 4> Leaves;
+ for (auto *Expr : ExprsInSubprogram)
+ if (Expr->getType()->isVoidTy() ||
+ !any_of(Expr->users(), [&ExprsInSubprogram](User *U) {
+ return ExprsInSubprogram.count(U);
+ }))
+ Leaves.push_back(Expr);
+ return Leaves;
+ }
+
+ /// Recursively traverse expression \p V starting at \p Leaf and add \p Leaf
+ /// to all visited expressions in \p Shared. Limit the matrix operations to
+ /// the ones in \p ExprsInSubprogram.
+ void collectSharedInfo(Value *Leaf, Value *V,
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+ DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) {
+
+ if (!ExprsInSubprogram.count(V))
+ return;
+
+ auto I = Shared.insert({V, {}});
+ I.first->second.insert(Leaf);
+
+ for (Value *Op : cast<Instruction>(V)->operand_values())
+ collectSharedInfo(Leaf, Op, ExprsInSubprogram, Shared);
+ }
+
+    /// Calculate the exclusive and shared op counts for the expression
+    /// starting at \p Root. Expressions used multiple times are counted once.
+ /// Limit the matrix operations to the ones in \p ExprsInSubprogram.
+ std::pair<OpInfoTy, OpInfoTy>
+ sumOpInfos(Value *Root, SmallPtrSetImpl<Value *> &ReusedExprs,
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+ DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared) const {
+ if (!ExprsInSubprogram.count(Root))
+ return {};
+
+ // Already counted this expression. Stop.
+ if (!ReusedExprs.insert(Root).second)
+ return {};
+
+ OpInfoTy SharedCount;
+ OpInfoTy Count;
+
+ auto I = Shared.find(Root);
+ auto CM = Inst2Matrix.find(Root);
+ if (I->second.size() == 1)
+ Count = CM->second.getOpInfo();
+ else
+ SharedCount = CM->second.getOpInfo();
+
+ for (Value *Op : cast<Instruction>(Root)->operand_values()) {
+ auto C = sumOpInfos(Op, ReusedExprs, ExprsInSubprogram, Shared);
+ Count += C.first;
+ SharedCount += C.second;
+ }
+ return {Count, SharedCount};
+ }
+
+ void emitRemarks() {
+ if (!ORE.allowExtraAnalysis(DEBUG_TYPE))
+ return;
+
+      // Map matrix operations to their containing subprograms by traversing
+ // the inlinedAt chain. If the function does not have a DISubprogram, we
+ // only map them to the containing function.
+ MapVector<DISubprogram *, SmallVector<Value *, 8>> Subprog2Exprs;
+ for (auto &KV : Inst2Matrix) {
+ if (Func.getSubprogram()) {
+ auto *I = cast<Instruction>(KV.first);
+ DILocation *Context = I->getDebugLoc();
+ while (Context) {
+ auto I =
+ Subprog2Exprs.insert({getSubprogram(Context->getScope()), {}});
+ I.first->second.push_back(KV.first);
+ Context = DebugLoc(Context).getInlinedAt();
+ }
+ } else {
+ auto I = Subprog2Exprs.insert({nullptr, {}});
+ I.first->second.push_back(KV.first);
+ }
+ }
+ for (auto &KV : Subprog2Exprs) {
+ SmallSetVector<Value *, 32> ExprsInSubprogram(KV.second.begin(),
+ KV.second.end());
+ auto Leaves = getExpressionLeaves(ExprsInSubprogram);
+
+ DenseMap<Value *, SmallPtrSet<Value *, 2>> Shared;
+ for (Value *Leaf : Leaves)
+ collectSharedInfo(Leaf, Leaf, ExprsInSubprogram, Shared);
+
+ // Generate remarks for each leaf.
+ for (auto *L : Leaves) {
+
+ DebugLoc Loc = cast<Instruction>(L)->getDebugLoc();
+ DILocation *Context = cast<Instruction>(L)->getDebugLoc();
+ while (Context) {
+ if (getSubprogram(Context->getScope()) == KV.first) {
+ Loc = Context;
+ break;
+ }
+ Context = DebugLoc(Context).getInlinedAt();
+ }
+
+ SmallPtrSet<Value *, 8> ReusedExprs;
+ OpInfoTy Counts, SharedCounts;
+ std::tie(Counts, SharedCounts) =
+ sumOpInfos(L, ReusedExprs, ExprsInSubprogram, Shared);
+
+ OptimizationRemark Rem(DEBUG_TYPE, "matrix-lowered", Loc,
+ cast<Instruction>(L)->getParent());
+
+ Rem << "Lowered with ";
+ Rem << ore::NV("NumStores", Counts.NumStores) << " stores, "
+ << ore::NV("NumLoads", Counts.NumLoads) << " loads, "
+ << ore::NV("NumComputeOps", Counts.NumComputeOps)
+ << " compute ops";
+
+ if (SharedCounts.NumStores > 0 || SharedCounts.NumLoads > 0 ||
+ SharedCounts.NumComputeOps > 0) {
+ Rem << ",\nadditionally "
+ << ore::NV("NumStores", SharedCounts.NumStores) << " stores, "
+ << ore::NV("NumLoads", SharedCounts.NumLoads) << " loads, "
+ << ore::NV("NumFPOps", SharedCounts.NumComputeOps)
+ << " compute ops"
+ << " are shared with other expressions";
+ }
+
+ Rem << ("\n" + linearize(L, Shared, ExprsInSubprogram, DL));
+ ORE.emit(Rem);
+ }
+ }
+ }
+
+ std::string
+ linearize(Value *L,
+ const DenseMap<Value *, SmallPtrSet<Value *, 2>> &Shared,
+ const SmallSetVector<Value *, 32> &ExprsInSubprogram,
+ const DataLayout &DL) {
+ ExprLinearizer Lin(DL, Inst2Matrix, Shared, ExprsInSubprogram, L);
+ Lin.linearizeExpr(L, 0, false, false);
+ return Lin.getResult();
+ }
+ };
+};
+} // namespace
+
+PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ auto &TTI = AM.getResult<TargetIRAnalysis>(F);
OptimizationRemarkEmitter *ORE = nullptr;
AAResults *AA = nullptr;
DominatorTree *DT = nullptr;
LoopInfo *LI = nullptr;
-
+
if (!Minimal) {
ORE = &AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
AA = &AM.getResult<AAManager>(F);
@@ -1984,66 +1984,66 @@ PreservedAnalyses LowerMatrixIntrinsicsPass::run(Function &F,
LI = &AM.getResult<LoopAnalysis>(F);
}
- LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
- if (LMT.Visit()) {
- PreservedAnalyses PA;
+ LowerMatrixIntrinsics LMT(F, TTI, AA, DT, LI, ORE);
+ if (LMT.Visit()) {
+ PreservedAnalyses PA;
if (!Minimal) {
PA.preserve<LoopAnalysis>();
PA.preserve<DominatorTreeAnalysis>();
}
- return PA;
- }
- return PreservedAnalyses::all();
-}
-
-namespace {
-
-class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
-public:
- static char ID;
-
- LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
- initializeLowerMatrixIntrinsicsLegacyPassPass(
- *PassRegistry::getPassRegistry());
- }
-
- bool runOnFunction(Function &F) override {
- auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
- auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
- auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
- auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+ return PA;
+ }
+ return PreservedAnalyses::all();
+}
+
+namespace {
+
+class LowerMatrixIntrinsicsLegacyPass : public FunctionPass {
+public:
+ static char ID;
+
+ LowerMatrixIntrinsicsLegacyPass() : FunctionPass(ID) {
+ initializeLowerMatrixIntrinsicsLegacyPassPass(
+ *PassRegistry::getPassRegistry());
+ }
+
+ bool runOnFunction(Function &F) override {
+ auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+ auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
+ auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
+ auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
LowerMatrixIntrinsics LMT(F, TTI, &AA, &DT, &LI, &ORE);
- bool C = LMT.Visit();
- return C;
- }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<TargetTransformInfoWrapperPass>();
- AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
- AU.addRequired<AAResultsWrapperPass>();
- AU.addRequired<DominatorTreeWrapperPass>();
- AU.addPreserved<DominatorTreeWrapperPass>();
- AU.addRequired<LoopInfoWrapperPass>();
- AU.addPreserved<LoopInfoWrapperPass>();
- }
-};
-} // namespace
-
-static const char pass_name[] = "Lower the matrix intrinsics";
-char LowerMatrixIntrinsicsLegacyPass::ID = 0;
-INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
- false, false)
-INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
-INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
-INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
- false, false)
-
-Pass *llvm::createLowerMatrixIntrinsicsPass() {
- return new LowerMatrixIntrinsicsLegacyPass();
-}
+ bool C = LMT.Visit();
+ return C;
+ }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override {
+ AU.addRequired<TargetTransformInfoWrapperPass>();
+ AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
+ AU.addRequired<AAResultsWrapperPass>();
+ AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addPreserved<DominatorTreeWrapperPass>();
+ AU.addRequired<LoopInfoWrapperPass>();
+ AU.addPreserved<LoopInfoWrapperPass>();
+ }
+};
+} // namespace
+
+static const char pass_name[] = "Lower the matrix intrinsics";
+char LowerMatrixIntrinsicsLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
+ false, false)
+INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(LowerMatrixIntrinsicsLegacyPass, DEBUG_TYPE, pass_name,
+ false, false)
+
+Pass *llvm::createLowerMatrixIntrinsicsPass() {
+ return new LowerMatrixIntrinsicsLegacyPass();
+}
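+
+// Usage sketch, assuming the usual opt spellings: the pass can be run on its
+// own with the new pass manager via
+//   opt -passes=lower-matrix-intrinsics -pass-remarks=lower-matrix-intrinsics \
+//       -pass-remarks-output=remarks.yaml -S input.ll
+// where the remark flags enable the output produced by RemarkGenerator above.
+// The legacy pass manager spelling is -lower-matrix-intrinsics.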
namespace {