aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/libs/llvm16/lib/CodeGen/MIRSampleProfile.cpp
blob: a8996a586909c651de7b5842fd4388eb563355ae (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
//===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file provides the implementation of the MIRSampleProfile loader, mainly
// for flow sensitive SampleFDO.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/MIRSampleProfile.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
#include "llvm/CodeGen/MachinePostDominators.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Function.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
#include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"

using namespace llvm;
using namespace sampleprof;
using namespace llvm::sampleprofutil;
using ProfileCount = Function::ProfileCount;

#define DEBUG_TYPE "fs-profile-loader"

// Debug/diagnostic command line knobs for the flow-sensitive (FS) MIR profile
// loader.

// When set, setBranchProbs() prints the branch probabilities it overwrites.
// The printing code is compiled only when NDEBUG is not defined and is further
// filtered by the two thresholds below.
static cl::opt<bool> ShowFSBranchProb(
    "show-fs-branchprob", cl::Hidden, cl::init(false),
    cl::desc("Print setting flow sensitive branch probabilities"));
// Threshold, in percent, compared against the absolute difference between the
// old and the new branch probability.
static cl::opt<unsigned> FSProfileDebugProbDiffThreshold(
    "fs-profile-debug-prob-diff-threshold", cl::init(10),
    cl::desc("Only show debug message if the branch probability is greater "
             "than this value (in percentage)."));

// Threshold compared against the (unscaled) weight of the source block.
static cl::opt<unsigned> FSProfileDebugBWThreshold(
    "fs-profile-debug-bw-threshold", cl::init(10000),
    cl::desc("Only show debug message if the source branch weight is greater "
             "than this value."));

// Request a view of the block frequency info before / after the loader ran on
// a function; both are additionally gated on -view-block-layout-with-bfi and
// -view-bfi-func-name in runOnMachineFunction().
static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
                                   cl::init(false),
                                   cl::desc("View BFI before MIR loader"));
static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
                                  cl::init(false),
                                  cl::desc("View BFI after MIR loader"));

// Pass identification; handed to the MachineFunctionPass base class by the
// MIRProfileLoaderPass constructor.
char MIRProfileLoaderPass::ID = 0;

// Pass registration under the name DEBUG_TYPE ("fs-profile-loader"). The
// dependencies listed here are the same five analyses that getAnalysisUsage()
// requires.
INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
                      "Load MIR Sample Profile",
                      /* cfg = */ false, /* is_analysis = */ false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfo)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
                    /* cfg = */ false, /* is_analysis = */ false)

// Namespace-level alias for the pass ID above.
char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID;

/// Factory for the MIR sample profile loader pass.
///
/// \param File          profile file name, forwarded to the pass constructor.
/// \param RemappingFile remapping file name, forwarded to the pass
///                      constructor.
/// \param P             FS discriminator pass this loader instance is for.
/// \returns a newly heap-allocated MIRProfileLoaderPass; nothing in this
///          function retains it, so the caller owns the result.
FunctionPass *llvm::createMIRProfileLoaderPass(std::string File,
                                               std::string RemappingFile,
                                               FSDiscriminatorPass P) {
  FunctionPass *LoaderPass = new MIRProfileLoaderPass(File, RemappingFile, P);
  return LoaderPass;
}

namespace llvm {

// The two options below are only declared here (their definitions live in the
// files named in the comments); this file reads them in
// MIRProfileLoaderPass::runOnMachineFunction() to decide whether to view BFI.

// Internal option used to control BFI display only after MBP pass.
// Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
// -view-block-layout-with-bfi={none | fraction | integer | count}
extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;

// Command line option to specify the name of the function for CFG dump
// Defined in Analysis/BlockFrequencyInfo.cpp:  -view-bfi-func-name=
extern cl::opt<std::string> ViewBlockFreqFuncName;

namespace afdo_detail {
/// IRTraits specialization for Machine IR: binds every generic name of the
/// traits to its Machine* counterpart (presumably consumed by
/// SampleProfileLoaderBaseImpl<MachineBasicBlock> -- the base class of
/// MIRProfileLoader).
template <> struct IRTraits<MachineBasicBlock> {
  // Core IR entities.
  using FunctionT = MachineFunction;
  using BasicBlockT = MachineBasicBlock;
  using InstructionT = MachineInstr;

  // Analyses.
  using BlockFrequencyInfoT = MachineBlockFrequencyInfo;
  using LoopT = MachineLoop;
  using LoopInfoPtrT = MachineLoopInfo *;
  using DominatorTreePtrT = MachineDominatorTree *;
  using PostDominatorTreeT = MachinePostDominatorTree;
  using PostDominatorTreePtrT = MachinePostDominatorTree *;

  // Optimization remarks.
  using OptRemarkEmitterT = MachineOptimizationRemarkEmitter;
  using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis;

  // CFG neighbour ranges; predecessors and successors use the same range
  // type.
  using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
  using SuccRangeT = PredRangeT;

  /// Returns the IR function that \p F wraps.
  static Function &getFunction(MachineFunction &F) {
    return F.getFunction();
  }

  /// Returns the entry node of \p F as reported by GraphTraits.
  static const MachineBasicBlock *getEntryBB(const MachineFunction *F) {
    using MFGraph = GraphTraits<const MachineFunction *>;
    return MFGraph::getEntryNode(F);
  }

  /// Returns the range of CFG predecessors of \p BB.
  static PredRangeT getPredecessors(MachineBasicBlock *BB) {
    return BB->predecessors();
  }

  /// Returns the range of CFG successors of \p BB.
  static SuccRangeT getSuccessors(MachineBasicBlock *BB) {
    return BB->successors();
  }
};
} // namespace afdo_detail

/// Machine IR instantiation of the sample profile loader, used by
/// MIRProfileLoaderPass for flow sensitive SampleFDO.
///
/// Usage as seen in this file: construct with the profile (and remapping) file
/// names, call setFSPass() and doInitialization() once per module, then call
/// setInitVals() followed by runOnFunction() for each machine function.
class MIRProfileLoader final
    : public SampleProfileLoaderBaseImpl<MachineBasicBlock> {
public:
  /// Installs the analyses used while processing a function. The pointers are
  /// stored as given; this class does not take ownership of them.
  void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT,
                   MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI,
                   MachineOptimizationRemarkEmitter *MORE) {
    DT = MDT;
    PDT = MPDT;
    LI = MLI;
    BFI = MBFI;
    ORE = MORE;
  }

  /// Records the FS discriminator pass \p Pass this loader works for and
  /// derives the discriminator bit range [LowBit, HighBit] that belongs to it.
  void setFSPass(FSDiscriminatorPass Pass) {
    P = Pass;
    LowBit = getFSPassBitBegin(P);
    HighBit = getFSPassBitEnd(P);
    assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
  }

  /// Creates a loader for the profile file \p Name and the remapping file
  /// \p RemapName; both names are copied into the base class.
  MIRProfileLoader(StringRef Name, StringRef RemapName)
      : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName)) {
  }

  /// Writes the propagated edge weights of \p F back as branch probabilities.
  void setBranchProbs(MachineFunction &F);
  /// Applies the profile to \p F; returns whether weights were changed.
  bool runOnFunction(MachineFunction &F);
  /// Creates the profile reader for \p M and reads the profile.
  bool doInitialization(Module &M);
  /// Returns false once reading the profile has failed.
  bool isValid() const { return ProfileIsValid; }

protected:
  friend class SampleCoverageTracker;

  // All members below carry default initializers so that the object never
  // holds indeterminate values before setInitVals()/setFSPass() are called.

  /// Hold the information of the basic block frequency.
  MachineBlockFrequencyInfo *BFI = nullptr;

  /// The FS discriminator pass this loader instance works for; set by
  /// setFSPass().
  FSDiscriminatorPass P = FSDiscriminatorPass::Base;

  // LowBit in the FS discriminator used by this instance. Note the number is
  // 0-based. Base discriminator uses bit 0 to bit 11.
  unsigned LowBit = 0;
  // HighBit in the FS discriminator used by this instance. Note the number
  // is 0-based.
  unsigned HighBit = 0;

  /// Cleared when the profile could not be read; see doInitialization().
  bool ProfileIsValid = true;
};

// Explicit specialization with an intentionally empty body: for Machine IR
// nothing is computed here. The dominator tree, post-dominator tree and loop
// info members (DT, PDT, LI) are instead assigned by
// MIRProfileLoader::setInitVals() from the analyses the pass obtains in
// runOnMachineFunction().
template <>
void SampleProfileLoaderBaseImpl<
    MachineBasicBlock>::computeDominanceAndLoopInfo(MachineFunction &F) {}

/// Converts the propagated block and edge weights of \p F into successor
/// branch probabilities.
///
/// Only blocks with at least two successors are touched. For each of them the
/// outgoing edge weights are summed and that sum is used as the denominator
/// (it replaces the block weight when the two disagree). Blocks whose edge
/// weights are all zero are skipped. Weights larger than UINT32_MAX are scaled
/// down by a common factor before the probability is built. A successor's
/// probability is only written when it differs from the current one.
///
/// When NDEBUG is not defined and -show-fs-branchprob is given, every change
/// that passes the two debug thresholds is printed to dbgs().
void MIRProfileLoader::setBranchProbs(MachineFunction &F) {
  LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
  for (auto &BI : F) {
    MachineBasicBlock *BB = &BI;
    if (BB->succ_size() < 2)
      continue;
    const MachineBasicBlock *EC = EquivalenceClass[BB];
    uint64_t BBWeight = BlockWeights[EC];
    uint64_t SumEdgeWeight = 0;
    for (MachineBasicBlock *Succ : BB->successors()) {
      Edge E = std::make_pair(BB, Succ);
      SumEdgeWeight += EdgeWeights[E];
    }

    // The edge weights are authoritative for the probabilities: use their sum
    // as the denominator so that the probabilities add up to one.
    if (BBWeight != SumEdgeWeight) {
      LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
                        << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
                        << "\n");
      BBWeight = SumEdgeWeight;
    }
    if (BBWeight == 0) {
      LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
      continue;
    }

#ifndef NDEBUG
    uint64_t BBWeightOrig = BBWeight;
#endif
    // Scale the weights so that they fit into 32 bits. The factor must be
    // 64 bits wide: BBWeight / MaxWeight + 1 can exceed UINT32_MAX (a 32-bit
    // factor would wrap to 0 when the quotient is exactly UINT32_MAX and the
    // following division would divide by zero). With a 64-bit factor,
    // BBWeight / Factor is always below MaxWeight.
    const uint64_t MaxWeight = std::numeric_limits<uint32_t>::max();
    uint64_t Factor = 1;
    if (BBWeight > MaxWeight) {
      Factor = BBWeight / MaxWeight + 1;
      BBWeight /= Factor;
      LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
    }

    for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
                                          SE = BB->succ_end();
         SI != SE; ++SI) {
      MachineBasicBlock *Succ = *SI;
      Edge E = std::make_pair(BB, Succ);
      uint64_t EdgeWeight = EdgeWeights[E];
      EdgeWeight /= Factor;

      // Each edge weight is one term of the sum that BBWeight was set to, so
      // after the common scaling it cannot exceed BBWeight.
      assert(BBWeight >= EdgeWeight &&
             "BBweight is smaller than EdgeWeight -- should not happen.\n");

      BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(BB, SI);
      BranchProbability NewProb(EdgeWeight, BBWeight);
      if (OldProb == NewProb)
        continue;
      BB->setSuccProbability(SI, NewProb);
#ifndef NDEBUG
      if (!ShowFSBranchProb)
        continue;
      bool Show = false;
      BranchProbability Diff;
      if (OldProb > NewProb)
        Diff = OldProb - NewProb;
      else
        Diff = NewProb - OldProb;
      // Report only sufficiently large changes on sufficiently hot blocks.
      Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100));
      Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);

      auto DIL = BB->findBranchDebugLoc();
      auto SuccDIL = Succ->findBranchDebugLoc();
      if (Show) {
        dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
               << Succ->getNumber() << "): ";
        if (DIL)
          dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
                 << DIL->getColumn();
        if (SuccDIL)
          dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
                 << ":" << SuccDIL->getColumn();
        dbgs() << " W=" << BBWeightOrig << "  " << OldProb << " --> " << NewProb
               << "\n";
      }
#endif
    }
  }
}

/// Creates the sample profile reader for module \p M and reads the profile.
///
/// \returns false when the reader could not be created (a diagnostic is
/// emitted through the module's context), true otherwise. A failure while
/// reading the profile does not change the return value; it is recorded in
/// ProfileIsValid and can be queried through isValid().
bool MIRProfileLoader::doInitialization(Module &M) {
  auto &Ctx = M.getContext();

  auto ReaderOrErr = sampleprof::SampleProfileReader::create(Filename, Ctx, P,
                                                             RemappingFilename);
  if (std::error_code EC = ReaderOrErr.getError()) {
    std::string Msg = "Could not open profile: " + EC.message();
    Ctx.diagnose(DiagnosticInfoSampleProfile(Filename, Msg));
    // No reader was installed for this module. Mark the profile invalid so
    // that callers checking isValid() skip runOnFunction(), which would
    // otherwise use Reader without one having been created here.
    ProfileIsValid = false;
    return false;
  }

  Reader = std::move(ReaderOrErr.get());
  Reader->setModule(&M);
  ProfileIsValid = (Reader->read() == sampleprof_error::success);
  Reader->getSummary();

  return true;
}

/// Applies the loaded profile to the machine function \p MF.
///
/// Per-function state is cleared first. The function is left alone (and false
/// is returned) when the reader has no samples, or only empty samples, for it,
/// or when getFunctionLoc() yields 0. Otherwise weights are computed and
/// propagated and the branch probabilities are rewritten from them.
///
/// \returns the result of computeAndPropagateWeights().
bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
  clearFunctionData(false);

  Samples = Reader->getSamplesFor(MF.getFunction());
  const bool HasSamples = Samples && !Samples->empty();
  if (!HasSamples || getFunctionLoc(MF) == 0)
    return false;

  DenseSet<GlobalValue::GUID> InlinedGUIDs;
  const bool WeightsChanged = computeAndPropagateWeights(MF, InlinedGUIDs);

  // Set the new BPI, BFI. This happens whether or not the weights changed.
  setBranchProbs(MF);

  return WeightsChanged;
}

} // namespace llvm

/// Builds the pass together with the MIRProfileLoader it drives.
///
/// \param FileName          profile file name; stored and given to the loader.
/// \param RemappingFileName remapping file name given to the loader.
/// \param P                 FS discriminator pass; determines the
///                          [LowBit, HighBit] discriminator bit range.
MIRProfileLoaderPass::MIRProfileLoaderPass(std::string FileName,
                                           std::string RemappingFileName,
                                           FSDiscriminatorPass P)
    : MachineFunctionPass(ID), ProfileFileName(FileName), P(P),
      MIRSampleLoader(
          std::make_unique<MIRProfileLoader>(FileName, RemappingFileName)) {
  // Discriminator bit range that belongs to this FS pass.
  const auto FirstBit = getFSPassBitBegin(P);
  const auto LastBit = getFSPassBitEnd(P);
  LowBit = FirstBit;
  HighBit = LastBit;
  assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
}

/// Runs the MIR profile loader on \p MF.
///
/// Does nothing when the loader reports an invalid profile. Otherwise the
/// required analyses are handed to the loader, the blocks are renumbered, the
/// loader is run and, if it reports a change, the block frequency info is
/// recalculated. Depending on the -fs-viewbfi-before / -fs-viewbfi-after
/// options the block frequency info is viewed before and/or after loading.
///
/// \returns the loader's result for this function.
bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
  if (!MIRSampleLoader->isValid())
    return false;

  LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
                    << MF.getFunction().getName() << "\n");

  MachineLoopInfo &MLI = getAnalysis<MachineLoopInfo>();
  MBFI = &getAnalysis<MachineBlockFrequencyInfo>();
  MIRSampleLoader->setInitVals(
      &getAnalysis<MachineDominatorTree>(),
      &getAnalysis<MachinePostDominatorTree>(), &MLI, MBFI,
      &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());

  // A view is shown only if it was requested, BFI display is not disabled and
  // the function name filter is either empty or names this function.
  const auto ShouldViewBFI = [&MF](bool Requested) {
    if (!Requested)
      return false;
    return ViewBlockLayoutWithBFI != GVDT_None &&
           (ViewBlockFreqFuncName.empty() ||
            MF.getFunction().getName().equals(ViewBlockFreqFuncName));
  };

  MF.RenumberBlocks();
  if (ShouldViewBFI(ViewBFIBefore))
    MBFI->view("MIR_Prof_loader_b." + MF.getName(), false);

  const bool ProfileApplied = MIRSampleLoader->runOnFunction(MF);
  if (ProfileApplied)
    MBFI->calculate(MF, *MBFI->getMBPI(), MLI);

  if (ShouldViewBFI(ViewBFIAfter))
    MBFI->view("MIR_prof_loader_a." + MF.getName(), false);

  return ProfileApplied;
}

/// Per-module setup: tells the loader which FS discriminator pass it serves
/// and lets it create its reader and read the profile for \p M.
///
/// \returns whatever MIRProfileLoader::doInitialization() returns.
bool MIRProfileLoaderPass::doInitialization(Module &M) {
  LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
                    << "\n");

  MIRProfileLoader &Loader = *MIRSampleLoader;
  Loader.setFSPass(P);
  const bool LoaderInitialized = Loader.doInitialization(M);
  return LoaderInitialized;
}

/// Declares the analyses this pass uses. All five are fetched with
/// getAnalysis<>() in runOnMachineFunction(); MachineLoopInfo is the only one
/// registered as a transitive requirement. The pass declares that it preserves
/// all analyses, then defers to the MachineFunctionPass base implementation.
void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<MachineBlockFrequencyInfo>();
  AU.addRequired<MachineDominatorTree>();
  AU.addRequired<MachinePostDominatorTree>();
  // Also passed to MBFI->calculate() when the block frequencies are
  // recomputed after loading.
  AU.addRequiredTransitive<MachineLoopInfo>();
  AU.addRequired<MachineOptimizationRemarkEmitterPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}