aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/libs/llvm12/lib/Analysis/Delinearization.cpp
blob: 87a41bbf16a59d942399d29152c1bb0a77986464 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
//===---- Delinearization.cpp - MultiDimensional Index Delinearization ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This implements an analysis pass that tries to delinearize all GEP
// instructions in all loops using the SCEV analysis functionality. This pass is
// only used for testing purposes: if your pass needs delinearization, please
// use the on-demand ScalarEvolution::delinearize() function.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/Delinearization.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/Passes.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

#define DL_NAME "delinearize"
#define DEBUG_TYPE DL_NAME

namespace {

/// FunctionPass wrapper around printDelinearization().
///
/// runOnFunction() only records the function and the analyses it needs;
/// the actual report is produced on demand by print().
class Delinearization : public FunctionPass {
  // Copying a pass is not supported. Deleting the constructor (rather than
  // declaring it and leaving it unimplemented) rejects every copy attempt at
  // compile time.
  Delinearization(const Delinearization &) = delete;

protected:
  // Populated by runOnFunction(). Initialized to null so the members never
  // hold indeterminate values before the pass has run.
  Function *F = nullptr;
  LoopInfo *LI = nullptr;
  ScalarEvolution *SE = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  Delinearization() : FunctionPass(ID) {
    initializeDelinearizationPass(*PassRegistry::getPassRegistry());
  }

  /// Caches the function, LoopInfo and ScalarEvolution; never modifies IR.
  bool runOnFunction(Function &F) override;

  /// Declares the analyses this pass requires and that it preserves all.
  void getAnalysisUsage(AnalysisUsage &AU) const override;

  /// Writes the delinearization report for the cached function to \p O.
  void print(raw_ostream &O, const Module *M = nullptr) const override;
};

/// Prints a delinearization report for every load, store and GEP in \p F.
///
/// For each such instruction the access function is evaluated in the loop
/// that \p LI reports for the instruction's block and then in each parent
/// loop in turn. Instructions that are not inside any loop produce no output
/// beyond the function banner.
void printDelinearization(raw_ostream &O, Function *F, LoopInfo *LI,
                          ScalarEvolution *SE) {
  O << "Delinearization on function " << F->getName() << ":\n";

  for (Instruction &Inst : instructions(F)) {
    // Only loads, stores and GEPs are of interest; skip everything else.
    const bool IsCandidate = isa<StoreInst>(Inst) || isa<LoadInst>(Inst) ||
                             isa<GetElementPtrInst>(Inst);
    if (!IsCandidate)
      continue;

    // Walk outwards through the surrounding loops; when the block is not in
    // a loop the walk is empty and nothing is printed for this instruction.
    const BasicBlock *Parent = Inst.getParent();
    for (Loop *Scope = LI->getLoopFor(Parent); Scope != nullptr;
         Scope = Scope->getParentLoop()) {
      const SCEV *Access = SE->getSCEVAtScope(getPointerOperand(&Inst), Scope);

      // Without a known base pointer there is nothing to subtract, so give
      // up on this instruction entirely (outer loops are not tried either).
      const SCEVUnknown *Base =
          dyn_cast<SCEVUnknown>(SE->getPointerBase(Access));
      if (!Base)
        break;
      Access = SE->getMinusSCEV(Access, Base);

      O << "\n"
        << "Inst:" << Inst << "\n"
        << "In Loop with Header: " << Scope->getHeader()->getName() << "\n"
        << "AccessFunction: " << *Access << "\n";

      SmallVector<const SCEV *, 3> Subscripts, Sizes;
      SE->delinearize(Access, Subscripts, Sizes, SE->getElementSize(&Inst));

      // A usable result has one size per subscript and at least one of each.
      const bool Delinearized = !Subscripts.empty() && !Sizes.empty() &&
                                Subscripts.size() == Sizes.size();
      if (!Delinearized) {
        O << "failed to delinearize\n";
        continue;
      }

      O << "Base offset: " << *Base << "\n";

      // The last entry of Sizes is reported as the element size; the entries
      // before it are printed as the array dimensions.
      O << "ArrayDecl[UnknownSize]";
      const SCEV *ElementBytes = Sizes.back();
      for (size_t Dim = 0, LastDim = Sizes.size() - 1; Dim != LastDim; ++Dim)
        O << "[" << *Sizes[Dim] << "]";
      O << " with elements of " << *ElementBytes << " bytes.\n";

      O << "ArrayRef";
      for (const SCEV *Subscript : Subscripts)
        O << "[" << *Subscript << "]";
      O << "\n";
    }
  }
}

} // end anonymous namespace

// Declares the analyses that runOnFunction() fetches, and that this pass
// preserves everything.
void Delinearization::getAnalysisUsage(AnalysisUsage &AU) const {
  // Inputs fetched by runOnFunction().
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addRequired<ScalarEvolutionWrapperPass>();

  // runOnFunction() always returns false, i.e. reports no IR change.
  AU.setPreservesAll();
}

// Remembers the function and its analyses so that print() can report on them
// later. Always returns false.
bool Delinearization::runOnFunction(Function &F) {
  auto &LoopInfoPass = getAnalysis<LoopInfoWrapperPass>();
  auto &ScalarEvolutionPass = getAnalysis<ScalarEvolutionWrapperPass>();

  LI = &LoopInfoPass.getLoopInfo();
  SE = &ScalarEvolutionPass.getSE();
  this->F = &F; // The parameter shadows the member, hence the explicit this.

  return false;
}

// Emits the delinearization report for the state stored by runOnFunction().
// The Module argument is unused.
void Delinearization::print(raw_ostream &O, const Module *) const {
  printDelinearization(O, this->F, this->LI, this->SE);
}

// Pass identification and registration under the DL_NAME command-line name.
char Delinearization::ID = 0;
static const char delinearization_name[] = "Delinearization";
INITIALIZE_PASS_BEGIN(Delinearization, DL_NAME, delinearization_name, true,
                      true)
// List every analysis that getAnalysisUsage() marks as required, so each one
// is initialized together with this pass.
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(Delinearization, DL_NAME, delinearization_name, true, true)

// Factory for the Delinearization function pass; returns a newly allocated
// instance.
FunctionPass *llvm::createDelinearizationPass() {
  return new Delinearization();
}

// Binds the printer pass to the stream that run() writes its report to.
DelinearizationPrinterPass::DelinearizationPrinterPass(raw_ostream &OS)
    : OS(OS) {}
// Prints the delinearization report for F to the stream given at
// construction, using the LoopAnalysis and ScalarEvolutionAnalysis results
// obtained from AM. Reports all analyses as preserved.
PreservedAnalyses DelinearizationPrinterPass::run(Function &F,
                                                  FunctionAnalysisManager &AM) {
  LoopInfo &Loops = AM.getResult<LoopAnalysis>(F);
  ScalarEvolution &Evolution = AM.getResult<ScalarEvolutionAnalysis>(F);

  printDelinearization(OS, &F, &Loops, &Evolution);
  return PreservedAnalyses::all();
}