aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/libs/llvm12/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp
blob: e419f454059fdb632cb1d8d4073cf0e8a53c78a0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
//==-- X86LoadValueInjectionLoadHardening.cpp - LVI load hardening for x86 --=//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// Description: This pass finds Load Value Injection (LVI) gadgets consisting
/// of a load from memory (i.e., SOURCE), and any operation that may transmit
/// the value loaded from memory over a covert channel, or use the value loaded
/// from memory to determine a branch/call target (i.e., SINK). After finding
/// all such gadgets in a given function, the pass minimally inserts LFENCE
/// instructions in such a manner that the following property is satisfied: for
/// all SOURCE+SINK pairs, all paths in the CFG from SOURCE to SINK contain at
/// least one LFENCE instruction. The algorithm that implements this minimal
/// insertion is influenced by an academic paper that minimally inserts memory
/// fences for high-performance concurrent programs:
///         http://www.cs.ucr.edu/~lesani/companion/oopsla15/OOPSLA15.pdf
/// The algorithm implemented in this pass is as follows:
/// 1. Build a condensed CFG (i.e., a GadgetGraph) consisting only of the
/// following components:
///    - SOURCE instructions (also includes function arguments)
///    - SINK instructions
///    - Basic block entry points
///    - Basic block terminators
///    - LFENCE instructions
/// 2. Analyze the GadgetGraph to determine which SOURCE+SINK pairs (i.e.,
/// gadgets) are already mitigated by existing LFENCEs. If all gadgets have been
/// mitigated, go to step 6.
/// 3. Use a heuristic or plugin to approximate minimal LFENCE insertion.
/// 4. Insert one LFENCE along each CFG edge that was cut in step 3.
/// 5. Go to step 2.
/// 6. If any LFENCEs were inserted, return `true` from runOnMachineFunction()
/// to tell LLVM that the function was modified.
///
//===----------------------------------------------------------------------===//

#include "ImmutableGraph.h"
#include "X86.h"
#include "X86Subtarget.h"
#include "X86TargetMachine.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h" 
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineDominanceFrontier.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/RDFGraph.h"
#include "llvm/CodeGen/RDFLiveness.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/DOTGraphTraits.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DynamicLibrary.h"
#include "llvm/Support/GraphWriter.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

#define PASS_KEY "x86-lvi-load"
#define DEBUG_TYPE PASS_KEY

STATISTIC(NumFences, "Number of LFENCEs inserted for LVI mitigation");
STATISTIC(NumFunctionsConsidered, "Number of functions analyzed");
STATISTIC(NumFunctionsMitigated, "Number of functions for which mitigations "
                                 "were deployed");
STATISTIC(NumGadgets, "Number of LVI gadgets detected during analysis");

// Hidden command-line options for this pass. Every option name is prefixed
// with PASS_KEY.

// Path of a shared library used to choose which edges to cut. When non-empty,
// runOnMachineFunction() loads it and looks up its `optimize_cut` symbol
// instead of using the built-in heuristic.
static cl::opt<std::string> OptimizePluginPath(
    PASS_KEY "-opt-plugin",
    cl::desc("Specify a plugin to optimize LFENCE insertion"), cl::Hidden);

// When set, a use of a loaded value by a branch is not recorded as a
// transmitter during gadget analysis (see getGadgetGraph()).
static cl::opt<bool> NoConditionalBranches(
    PASS_KEY "-no-cbranch",
    cl::desc("Don't treat conditional branches as disclosure gadgets. This "
             "may improve performance, at the cost of security."),
    cl::init(false), cl::Hidden);

// When set, the gadget graph of each function is written to
// "lvi.<function name>.dot" before hardening proceeds.
static cl::opt<bool> EmitDot(
    PASS_KEY "-dot",
    cl::desc(
        "For each function, emit a dot graph depicting potential LVI gadgets"),
    cl::init(false), cl::Hidden);

// Same dump as EmitDot, but the pass returns right after writing the graph
// and leaves the function unmodified.
static cl::opt<bool> EmitDotOnly(
    PASS_KEY "-dot-only",
    cl::desc("For each function, emit a dot graph depicting potential LVI "
             "gadgets, and do not insert any fences"),
    cl::init(false), cl::Hidden);

// When set, the gadget graph is written to outs() and the pass returns
// without inserting any fences.
static cl::opt<bool> EmitDotVerify(
    PASS_KEY "-dot-verify",
    cl::desc("For each function, emit a dot graph to stdout depicting "
             "potential LVI gadgets, used for testing purposes only"),
    cl::init(false), cl::Hidden);

// Handle of the optional cut-optimization plugin; it stays invalid until
// runOnMachineFunction() loads the library named by OptimizePluginPath.
static llvm::sys::DynamicLibrary OptimizeDL;

// Signature of the plugin's `optimize_cut` entry point. The graph is passed as
// flat arrays: `Nodes` (NodesSize entries plus one terminator entry) holds, for
// each node, the index of its first edge; `Edges` holds each edge's
// destination node index; `EdgeValues` holds each edge's value. The plugin
// reports its decision through `CutEdges`, where a non-zero entry marks the
// edge at that index as cut.
using OptimizeCutT = int (*)(unsigned int *Nodes, unsigned int NodesSize,
                             unsigned int *Edges, int *EdgeValues,
                             int *CutEdges /* out */, unsigned int EdgesSize);

// Resolved address of `optimize_cut`, or null while no plugin is loaded.
static OptimizeCutT OptimizeCut = nullptr;

namespace {

/// Immutable graph whose nodes carry a MachineInstr pointer and whose edges
/// carry an int. An edge value equal to GadgetEdgeSentinel marks a gadget
/// (source -> sink) edge; every other value marks a CFG edge. A node whose
/// value is ArgNodeSentinel stands for the function's arguments rather than a
/// real instruction.
struct MachineGadgetGraph : ImmutableGraph<MachineInstr *, int> {
  static constexpr int GadgetEdgeSentinel = -1;
  static constexpr MachineInstr *const ArgNodeSentinel = nullptr;

  using GraphT = ImmutableGraph<MachineInstr *, int>;
  using Node = typename GraphT::Node;
  using Edge = typename GraphT::Edge;
  using size_type = typename GraphT::size_type;

  // Bookkeeping supplied by whoever builds or trims the graph.
  int NumFences;
  int NumGadgets;

  MachineGadgetGraph(std::unique_ptr<Node[]> Nodes,
                     std::unique_ptr<Edge[]> Edges, size_type NodesSize,
                     size_type EdgesSize, int NumFences = 0, int NumGadgets = 0)
      : GraphT(std::move(Nodes), std::move(Edges), NodesSize, EdgesSize),
        NumFences(NumFences), NumGadgets(NumGadgets) {}

  /// True when \p E links a gadget source to one of its sinks.
  static inline bool isGadgetEdge(const Edge &E) {
    return E.getValue() == GadgetEdgeSentinel;
  }

  /// True when \p E is an ordinary control-flow edge, i.e. not a gadget edge.
  static inline bool isCFGEdge(const Edge &E) {
    return !isGadgetEdge(E);
  }
};

/// Machine function pass that builds a gadget graph for each function and
/// inserts LFENCE instructions along the CFG edges selected either by an
/// external plugin or by the built-in heuristic.
class X86LoadValueInjectionLoadHardeningPass : public MachineFunctionPass {
public:
  X86LoadValueInjectionLoadHardeningPass() : MachineFunctionPass(ID) {}

  StringRef getPassName() const override {
    return "X86 Load Value Injection (LVI) Load Hardening";
  }
  void getAnalysisUsage(AnalysisUsage &AU) const override;
  bool runOnMachineFunction(MachineFunction &MF) override;

  static char ID;

private:
  using GraphBuilder = ImmutableGraphBuilder<MachineGadgetGraph>;
  using Edge = MachineGadgetGraph::Edge;
  using Node = MachineGadgetGraph::Node;
  using EdgeSet = MachineGadgetGraph::EdgeSet;
  using NodeSet = MachineGadgetGraph::NodeSet;

  // Per-function state, assigned by runOnMachineFunction(). The pointers are
  // null until then so that a premature use is a clean null dereference
  // instead of a read through an indeterminate pointer.
  const X86Subtarget *STI = nullptr;
  const TargetInstrInfo *TII = nullptr;
  const TargetRegisterInfo *TRI = nullptr;

  /// Builds the gadget graph for \p MF; returns null when no gadget is found.
  std::unique_ptr<MachineGadgetGraph>
  getGadgetGraph(MachineFunction &MF, const MachineLoopInfo &MLI,
                 const MachineDominatorTree &MDT,
                 const MachineDominanceFrontier &MDF) const;
  /// Hardens \p MF using the cuts chosen by the loaded plugin; returns the
  /// number of fences inserted.
  int hardenLoadsWithPlugin(MachineFunction &MF,
                            std::unique_ptr<MachineGadgetGraph> Graph) const;
  /// Hardens \p MF using the built-in greedy heuristic; returns the number of
  /// fences inserted.
  int hardenLoadsWithHeuristic(MachineFunction &MF,
                               std::unique_ptr<MachineGadgetGraph> Graph) const;
  /// Collects already-mitigated edges/nodes into the two sets and returns the
  /// number of gadget edges that remain unmitigated.
  int elimMitigatedEdgesAndNodes(MachineGadgetGraph &G,
                                 EdgeSet &ElimEdges /* in, out */,
                                 NodeSet &ElimNodes /* in, out */) const;
  /// Returns \p Graph with all already-mitigated edges and nodes removed.
  std::unique_ptr<MachineGadgetGraph>
  trimMitigatedEdges(std::unique_ptr<MachineGadgetGraph> Graph) const;
  /// Inserts an LFENCE for each edge in \p CutEdges (skipping redundant ones)
  /// and returns the number of fences actually inserted.
  int insertFences(MachineFunction &MF, MachineGadgetGraph &G,
                   EdgeSet &CutEdges /* in, out */) const;
  bool instrUsesRegToAccessMemory(const MachineInstr &I, unsigned Reg) const;
  bool instrUsesRegToBranch(const MachineInstr &I, unsigned Reg) const;
  /// An instruction counts as a fence when it is an LFENCE, or when it is a
  /// call and the subtarget uses LVI control-flow integrity. A null \p MI is
  /// never a fence.
  inline bool isFence(const MachineInstr *MI) const {
    return MI && (MI->getOpcode() == X86::LFENCE ||
                  (STI->useLVIControlFlowIntegrity() && MI->isCall()));
  }
};

} // end anonymous namespace

namespace llvm {

// Let generic graph utilities treat a MachineGadgetGraph pointer exactly like
// a pointer to its ImmutableGraph base by inheriting the base's traits.
template <>
struct GraphTraits<MachineGadgetGraph *>
    : GraphTraits<ImmutableGraph<MachineInstr *, int> *> {};

template <>
struct DOTGraphTraits<MachineGadgetGraph *> : DefaultDOTGraphTraits {
  using GraphType = MachineGadgetGraph;
  using Traits = llvm::GraphTraits<GraphType *>;
  using NodeRef = typename Traits::NodeRef;
  using EdgeRef = typename Traits::EdgeRef;
  using ChildIteratorType = typename Traits::ChildIteratorType;
  using ChildEdgeIteratorType = typename Traits::ChildEdgeIteratorType;

  DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {} 

  std::string getNodeLabel(NodeRef Node, GraphType *) {
    if (Node->getValue() == MachineGadgetGraph::ArgNodeSentinel)
      return "ARGS";

    std::string Str;
    raw_string_ostream OS(Str);
    OS << *Node->getValue();
    return OS.str();
  }

  static std::string getNodeAttributes(NodeRef Node, GraphType *) {
    MachineInstr *MI = Node->getValue();
    if (MI == MachineGadgetGraph::ArgNodeSentinel)
      return "color = blue";
    if (MI->getOpcode() == X86::LFENCE)
      return "color = green";
    return "";
  }

  static std::string getEdgeAttributes(NodeRef, ChildIteratorType E,
                                       GraphType *) {
    int EdgeVal = (*E.getCurrent()).getValue();
    return EdgeVal >= 0 ? "label = " + std::to_string(EdgeVal)
                        : "color = red, style = \"dashed\"";
  }
};

} // end namespace llvm

// Out-of-class definitions of the static constexpr members declared inside
// MachineGadgetGraph (needed by pre-C++17 rules when the members are
// odr-used).
constexpr MachineInstr *MachineGadgetGraph::ArgNodeSentinel;
constexpr int MachineGadgetGraph::GadgetEdgeSentinel;

// Storage for the pass ID that is handed to the MachineFunctionPass
// constructor.
char X86LoadValueInjectionLoadHardeningPass::ID = 0;

/// Declares the analyses this pass depends on. Loop info, the dominator tree
/// and the dominance frontier are all fetched in runOnMachineFunction() and
/// passed to getGadgetGraph(). The pass also declares that it preserves the
/// CFG.
void X86LoadValueInjectionLoadHardeningPass::getAnalysisUsage(
    AnalysisUsage &AU) const {
  // Start from the defaults of the base class.
  MachineFunctionPass::getAnalysisUsage(AU);
  AU.addRequired<MachineLoopInfo>();
  AU.addRequired<MachineDominatorTree>();
  AU.addRequired<MachineDominanceFrontier>();
  AU.setPreservesCFG();
}

/// Writes \p G to \p OS as a dot graph (full node names) whose title mentions
/// the name of \p MF.
static void writeGadgetGraph(raw_ostream &OS, MachineFunction &MF,
                             MachineGadgetGraph *G) {
  // Materialize the title first so the call below stays short.
  const std::string GraphTitle =
      ("Speculative gadgets for \"" + MF.getName() + "\" function").str();
  WriteGraph(OS, G, /*ShortNames*/ false, GraphTitle);
}

/// Entry point of the pass for one machine function.
///
/// Returns false without touching \p MF when LVI load hardening is disabled
/// for the subtarget, when the function is skipped, when no gadget is found,
/// or when one of the "dot only"/"dot verify" options is active. Otherwise it
/// inserts LFENCEs (through the plugin if one is configured, else through the
/// greedy heuristic) and returns true iff at least one fence was inserted.
/// Aborts with a fatal error on non-64-bit subtargets or if the plugin cannot
/// be loaded.
bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction(
    MachineFunction &MF) {
  LLVM_DEBUG(dbgs() << "***** " << getPassName() << " : " << MF.getName()
                    << " *****\n");
  STI = &MF.getSubtarget<X86Subtarget>();
  if (!STI->useLVILoadHardening())
    return false;

  // FIXME: support 32-bit
  if (!STI->is64Bit())
    report_fatal_error("LVI load hardening is only supported on 64-bit", false);

  // Don't skip functions with the "optnone" attr but participate in opt-bisect.
  const Function &F = MF.getFunction();
  if (!F.hasOptNone() && skipFunction(F))
    return false;

  ++NumFunctionsConsidered;
  TII = STI->getInstrInfo();
  TRI = STI->getRegisterInfo();
  LLVM_DEBUG(dbgs() << "Building gadget graph...\n");
  const auto &MLI = getAnalysis<MachineLoopInfo>();
  const auto &MDT = getAnalysis<MachineDominatorTree>();
  const auto &MDF = getAnalysis<MachineDominanceFrontier>();
  std::unique_ptr<MachineGadgetGraph> Graph = getGadgetGraph(MF, MLI, MDT, MDF);
  LLVM_DEBUG(dbgs() << "Building gadget graph... Done\n");
  if (Graph == nullptr)
    return false; // didn't find any gadgets

  if (EmitDotVerify) {
    writeGadgetGraph(outs(), MF, Graph.get());
    return false;
  }

  if (EmitDot || EmitDotOnly) {
    LLVM_DEBUG(dbgs() << "Emitting gadget graph...\n");
    std::error_code FileError;
    std::string FileName = "lvi.";
    FileName += MF.getName();
    FileName += ".dot";
    raw_fd_ostream FileOut(FileName, FileError);
    if (FileError) {
      // The file could not be opened: report it and skip the dump rather than
      // writing to (and closing) a stream that never opened.
      errs() << FileError.message() << '\n';
    } else {
      writeGadgetGraph(FileOut, MF, Graph.get());
      FileOut.close();
    }
    LLVM_DEBUG(dbgs() << "Emitting gadget graph... Done\n");
    if (EmitDotOnly)
      return false;
  }

  int FencesInserted;
  if (!OptimizePluginPath.empty()) {
    // Load the plugin lazily, once; later functions reuse the same handle.
    if (!OptimizeDL.isValid()) {
      std::string ErrorMsg;
      OptimizeDL = llvm::sys::DynamicLibrary::getPermanentLibrary(
          OptimizePluginPath.c_str(), &ErrorMsg);
      if (!ErrorMsg.empty())
        report_fatal_error("Failed to load opt plugin: \"" + ErrorMsg + '\"');
      OptimizeCut = (OptimizeCutT)OptimizeDL.getAddressOfSymbol("optimize_cut");
      if (!OptimizeCut)
        report_fatal_error("Invalid optimization plugin");
    }
    FencesInserted = hardenLoadsWithPlugin(MF, std::move(Graph));
  } else { // Use the default greedy heuristic
    FencesInserted = hardenLoadsWithHeuristic(MF, std::move(Graph));
  }

  if (FencesInserted > 0)
    ++NumFunctionsMitigated;
  NumFences += FencesInserted;
  return (FencesInserted > 0);
}

/// Builds the gadget graph for \p MF.
///
/// The function works in two phases:
///  1. It walks def-use chains in a register dataflow graph to find, for every
///     def produced by a load (and every def owned by a phi in the entry
///     block, which is how function arguments are analyzed here), the
///     instructions that may transmit the value. Each source/transmitter pair
///     becomes a gadget edge. Fence instructions are added as nodes as well.
///  2. It traverses the CFG and links the collected nodes, block entry
///     instructions and block terminators with CFG edges whose value is a
///     loop depth obtained from \p MLI.
///
/// \returns the finished graph, or nullptr if no gadget was found. As a side
/// effect the NumGadgets statistic is increased by the number of gadgets.
std::unique_ptr<MachineGadgetGraph>
X86LoadValueInjectionLoadHardeningPass::getGadgetGraph(
    MachineFunction &MF, const MachineLoopInfo &MLI,
    const MachineDominatorTree &MDT,
    const MachineDominanceFrontier &MDF) const {
  using namespace rdf;

  // Build the Register Dataflow Graph using the RDF framework
  TargetOperandInfo TOI{*TII};
  DataFlowGraph DFG{MF, *TII, *TRI, MDT, MDF, TOI};
  DFG.build();
  Liveness L{MF.getRegInfo(), DFG};
  L.computePhiInfo();

  GraphBuilder Builder;
  using GraphIter = typename GraphBuilder::BuilderNodeRef;
  DenseMap<MachineInstr *, GraphIter> NodeMap;
  int FenceCount = 0, GadgetCount = 0;
  // Returns the builder node for `MI`, creating it on first request. The
  // second member of the result tells whether the node was newly created.
  auto MaybeAddNode = [&NodeMap, &Builder](MachineInstr *MI) {
    auto Ref = NodeMap.find(MI);
    if (Ref == NodeMap.end()) {
      auto I = Builder.addVertex(MI);
      NodeMap[MI] = I;
      return std::pair<GraphIter, bool>{I, true};
    }
    return std::pair<GraphIter, bool>{Ref->getSecond(), false};
  };

  // The `Transmitters` map memoizes transmitters found for each def. If a def
  // has not yet been analyzed, then it will not appear in the map. If a def
  // has been analyzed and was determined not to have any transmitters, then
  // its list of transmitters will be empty.
  DenseMap<NodeId, std::vector<NodeId>> Transmitters;

  // Analyze all machine instructions to find gadgets and LFENCEs, adding
  // each interesting value to `Nodes`
  auto AnalyzeDef = [&](NodeAddr<DefNode *> SourceDef) {
    // The visited sets are local to one source def; only `Transmitters` is
    // shared between invocations of AnalyzeDef.
    SmallSet<NodeId, 8> UsesVisited, DefsVisited;
    std::function<void(NodeAddr<DefNode *>)> AnalyzeDefUseChain =
        [&](NodeAddr<DefNode *> Def) {
          if (Transmitters.find(Def.Id) != Transmitters.end())
            return; // Already analyzed `Def`

          // Use RDF to find all the uses of `Def`
          rdf::NodeSet Uses;
          RegisterRef DefReg = Def.Addr->getRegRef(DFG); 
          for (auto UseID : L.getAllReachedUses(DefReg, Def)) {
            auto Use = DFG.addr<UseNode *>(UseID);
            if (Use.Addr->getFlags() & NodeAttrs::PhiRef) { // phi node
              // Replace a phi use by the real uses of the owning phi whose
              // register aliases `DefReg`.
              NodeAddr<PhiNode *> Phi = Use.Addr->getOwner(DFG);
              for (auto I : L.getRealUses(Phi.Id)) {
                if (DFG.getPRI().alias(RegisterRef(I.first), DefReg)) {
                  for (auto UA : I.second)
                    Uses.emplace(UA.first);
                }
              }
            } else { // not a phi node
              Uses.emplace(UseID);
            }
          }

          // For each use of `Def`, we want to know whether:
          // (1) The use can leak the Def'ed value,
          // (2) The use can further propagate the Def'ed value to more defs
          for (auto UseID : Uses) {
            if (!UsesVisited.insert(UseID).second)
              continue; // Already visited this use of `Def`

            auto Use = DFG.addr<UseNode *>(UseID);
            assert(!(Use.Addr->getFlags() & NodeAttrs::PhiRef));
            MachineOperand &UseMO = Use.Addr->getOp();
            MachineInstr &UseMI = *UseMO.getParent();
            assert(UseMO.isReg());

            // We naively assume that an instruction propagates any loaded
            // uses to all defs unless the instruction is a call, in which
            // case all arguments will be treated as gadget sources during
            // analysis of the callee function.
            if (UseMI.isCall())
              continue;

            // Check whether this use can transmit (leak) its value.
            if (instrUsesRegToAccessMemory(UseMI, UseMO.getReg()) ||
                (!NoConditionalBranches &&
                 instrUsesRegToBranch(UseMI, UseMO.getReg()))) {
              Transmitters[Def.Id].push_back(Use.Addr->getOwner(DFG).Id);
              if (UseMI.mayLoad())
                continue; // Found a transmitting load -- no need to continue
                          // traversing its defs (i.e., this load will become
                          // a new gadget source anyways).
            }

            // Check whether the use propagates to more defs.
            NodeAddr<InstrNode *> Owner{Use.Addr->getOwner(DFG)};
            rdf::NodeList AnalyzedChildDefs;
            for (auto &ChildDef :
                 Owner.Addr->members_if(DataFlowGraph::IsDef, DFG)) {
              if (!DefsVisited.insert(ChildDef.Id).second)
                continue; // Already visited this def
              // NOTE(review): this tests the Dead attribute of `Def`, not of
              // `ChildDef` -- confirm that this is intended.
              if (Def.Addr->getAttrs() & NodeAttrs::Dead)
                continue;
              if (Def.Id == ChildDef.Id)
                continue; // `Def` uses itself (e.g., increment loop counter)

              AnalyzeDefUseChain(ChildDef);

              // `Def` inherits all of its child defs' transmitters.
              for (auto TransmitterId : Transmitters[ChildDef.Id])
                Transmitters[Def.Id].push_back(TransmitterId);
            }
          }

          // Note that this statement adds `Def.Id` to the map if no
          // transmitters were found for `Def`.
          auto &DefTransmitters = Transmitters[Def.Id];

          // Remove duplicate transmitters
          llvm::sort(DefTransmitters);
          DefTransmitters.erase(
              std::unique(DefTransmitters.begin(), DefTransmitters.end()),
              DefTransmitters.end());
        };

    // Find all of the transmitters
    AnalyzeDefUseChain(SourceDef);
    auto &SourceDefTransmitters = Transmitters[SourceDef.Id];
    if (SourceDefTransmitters.empty())
      return; // No transmitters for `SourceDef`

    // A def that belongs to a phi is represented by the argument pseudo-node;
    // any other def is represented by the instruction that contains it.
    MachineInstr *Source = SourceDef.Addr->getFlags() & NodeAttrs::PhiRef
                               ? MachineGadgetGraph::ArgNodeSentinel
                               : SourceDef.Addr->getOp().getParent();
    auto GadgetSource = MaybeAddNode(Source);
    // Each transmitter is a sink for `SourceDef`.
    for (auto TransmitterId : SourceDefTransmitters) {
      MachineInstr *Sink = DFG.addr<StmtNode *>(TransmitterId).Addr->getCode();
      auto GadgetSink = MaybeAddNode(Sink);
      // Add the gadget edge to the graph.
      Builder.addEdge(MachineGadgetGraph::GadgetEdgeSentinel,
                      GadgetSource.first, GadgetSink.first);
      ++GadgetCount;
    }
  };

  LLVM_DEBUG(dbgs() << "Analyzing def-use chains to find gadgets\n");
  // Analyze function arguments
  NodeAddr<BlockNode *> EntryBlock = DFG.getFunc().Addr->getEntryBlock(DFG);
  for (NodeAddr<PhiNode *> ArgPhi :
       EntryBlock.Addr->members_if(DataFlowGraph::IsPhi, DFG)) {
    NodeList Defs = ArgPhi.Addr->members_if(DataFlowGraph::IsDef, DFG);
    llvm::for_each(Defs, AnalyzeDef);
  }
  // Analyze every instruction in MF
  for (NodeAddr<BlockNode *> BA : DFG.getFunc().Addr->members(DFG)) {
    for (NodeAddr<StmtNode *> SA :
         BA.Addr->members_if(DataFlowGraph::IsCode<NodeAttrs::Stmt>, DFG)) {
      MachineInstr *MI = SA.Addr->getCode();
      // Fences only become nodes; the defs of loads are analyzed as potential
      // gadget sources. An instruction that is a fence is not analyzed as a
      // load.
      if (isFence(MI)) {
        MaybeAddNode(MI);
        ++FenceCount;
      } else if (MI->mayLoad()) {
        NodeList Defs = SA.Addr->members_if(DataFlowGraph::IsDef, DFG);
        llvm::for_each(Defs, AnalyzeDef);
      }
    }
  }
  LLVM_DEBUG(dbgs() << "Found " << FenceCount << " fences\n");
  LLVM_DEBUG(dbgs() << "Found " << GadgetCount << " gadgets\n");
  if (GadgetCount == 0)
    return nullptr;
  NumGadgets += GadgetCount;

  // Traverse CFG to build the rest of the graph
  SmallSet<MachineBasicBlock *, 8> BlocksVisited;
  // `GI` is the graph node from which control enters `MBB`, and `ParentDepth`
  // is the loop depth of the block that node came from; it becomes the value
  // of the edge into the first instruction of `MBB`. Edges inside the block
  // carry the loop depth of `MBB` itself.
  std::function<void(MachineBasicBlock *, GraphIter, unsigned)> TraverseCFG =
      [&](MachineBasicBlock *MBB, GraphIter GI, unsigned ParentDepth) {
        unsigned LoopDepth = MLI.getLoopDepth(MBB);
        if (!MBB->empty()) {
          // Always add the first instruction in each block
          auto NI = MBB->begin();
          auto BeginBB = MaybeAddNode(&*NI);
          Builder.addEdge(ParentDepth, GI, BeginBB.first);
          // The incoming edge is added on every visit, but the body of a
          // block is processed only once.
          if (!BlocksVisited.insert(MBB).second)
            return;

          // Add any instructions within the block that are gadget components
          GI = BeginBB.first;
          while (++NI != MBB->end()) {
            auto Ref = NodeMap.find(&*NI);
            if (Ref != NodeMap.end()) {
              Builder.addEdge(LoopDepth, GI, Ref->getSecond());
              GI = Ref->getSecond();
            }
          }

          // Always add the terminator instruction, if one exists
          auto T = MBB->getFirstTerminator();
          if (T != MBB->end()) {
            auto EndBB = MaybeAddNode(&*T);
            // If the terminator already had a node, the loop above has
            // already linked it, so only a newly created node needs an edge.
            if (EndBB.second)
              Builder.addEdge(LoopDepth, GI, EndBB.first);
            GI = EndBB.first;
          }
        }
        for (MachineBasicBlock *Succ : MBB->successors())
          TraverseCFG(Succ, GI, LoopDepth);
      };
  // ArgNodeSentinel is a pseudo-instruction that represents MF args in the
  // GadgetGraph
  GraphIter ArgNode = MaybeAddNode(MachineGadgetGraph::ArgNodeSentinel).first;
  TraverseCFG(&MF.front(), ArgNode, 0);
  std::unique_ptr<MachineGadgetGraph> G{Builder.get(FenceCount, GadgetCount)};
  LLVM_DEBUG(dbgs() << "Found " << G->nodes_size() << " nodes\n");
  return G;
}

/// Collects the parts of \p G that are already mitigated.
///
/// If the graph contains fences, every fence node is added to \p ElimNodes,
/// and every edge entering or leaving a fence is added to \p ElimEdges. Then,
/// for each gadget source, every gadget edge whose sink can no longer be
/// reached over the remaining CFG edges is added to \p ElimEdges too.
///
/// \returns the number of remaining gadget edges that could not be eliminated
int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes(
    MachineGadgetGraph &G, EdgeSet &ElimEdges /* in, out */,
    NodeSet &ElimNodes /* in, out */) const {
  if (G.NumFences > 0) {
    // Eliminate fences and CFG edges that ingress and egress the fence, as
    // they are trivially mitigated.
    for (const Edge &E : G.edges()) {
      const Node *Dest = E.getDest();
      if (isFence(Dest->getValue())) {
        ElimNodes.insert(*Dest);
        ElimEdges.insert(E);
        for (const Edge &DE : Dest->edges())
          ElimEdges.insert(DE);
      }
    }
  }

  // Find and eliminate gadget edges that have been mitigated.
  int RemainingGadgets = 0;
  NodeSet ReachableNodes{G};

  // Depth-first search over CFG edges that have not been eliminated. The
  // starting node is only recorded as reachable if some path leads back to
  // it. The helper is built once and reused for every gadget source.
  std::function<void(const Node *, bool)> FindReachableNodes =
      [&](const Node *N, bool FirstNode) {
        if (!FirstNode)
          ReachableNodes.insert(*N);
        for (const Edge &E : N->edges()) {
          const Node *Dest = E.getDest();
          if (MachineGadgetGraph::isCFGEdge(E) && !ElimEdges.contains(E) &&
              !ReachableNodes.contains(*Dest))
            FindReachableNodes(Dest, false);
        }
      };

  for (const Node &RootN : G.nodes()) {
    if (llvm::none_of(RootN.edges(), MachineGadgetGraph::isGadgetEdge))
      continue; // skip this node if it isn't a gadget source

    // Find all of the nodes that are CFG-reachable from RootN using DFS
    ReachableNodes.clear();
    FindReachableNodes(&RootN, true);

    // Any gadget whose sink is unreachable has been mitigated
    for (const Edge &E : RootN.edges()) {
      if (!MachineGadgetGraph::isGadgetEdge(E))
        continue;
      if (ReachableNodes.contains(*E.getDest()))
        ++RemainingGadgets; // This gadget's sink is reachable
      else
        ElimEdges.insert(E); // Sink is unreachable, and therefore mitigated
    }
  }
  return RemainingGadgets;
}

std::unique_ptr<MachineGadgetGraph>
X86LoadValueInjectionLoadHardeningPass::trimMitigatedEdges(
    std::unique_ptr<MachineGadgetGraph> Graph) const {
  NodeSet ElimNodes{*Graph}; 
  EdgeSet ElimEdges{*Graph}; 
  int RemainingGadgets =
      elimMitigatedEdgesAndNodes(*Graph, ElimEdges, ElimNodes);
  if (ElimEdges.empty() && ElimNodes.empty()) {
    Graph->NumFences = 0;
    Graph->NumGadgets = RemainingGadgets;
  } else {
    Graph = GraphBuilder::trim(*Graph, ElimNodes, ElimEdges, 0 /* NumFences */,
                               RemainingGadgets);
  }
  return Graph;
}

/// Hardens \p MF by repeatedly asking the loaded plugin which edges to cut.
///
/// Each round trims already-mitigated parts from the graph, stops once no
/// gadget is left, flattens the graph into plain arrays for `OptimizeCut`,
/// inserts fences for the edges the plugin marked, and removes those edges
/// from the graph before the next round.
///
/// \returns the total number of fences inserted.
int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin(
    MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const {
  int FencesInserted = 0;

  for (;;) {
    LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n");
    Graph = trimMitigatedEdges(std::move(Graph));
    LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n");
    if (Graph->NumGadgets == 0)
      return FencesInserted;

    LLVM_DEBUG(dbgs() << "Cutting edges...\n");
    const auto NodeCount = Graph->nodes_size();
    const auto EdgeCount = Graph->edges_size();
    EdgeSet CutEdges{*Graph};
    auto Nodes =
        std::make_unique<unsigned int[]>(NodeCount + 1 /* terminator node */);
    auto Edges = std::make_unique<unsigned int[]>(EdgeCount);
    auto EdgeCuts = std::make_unique<int[]>(EdgeCount);
    auto EdgeValues = std::make_unique<int[]>(EdgeCount);

    // Nodes[i] is the index of node i's first edge; the extra last entry
    // holds the total edge count.
    for (const Node &N : Graph->nodes())
      Nodes[Graph->getNodeIndex(N)] = Graph->getEdgeIndex(*N.edges_begin());
    Nodes[NodeCount] = EdgeCount; // terminator node

    // Edges[j] is the destination node index of edge j, EdgeValues[j] its
    // value.
    for (const Edge &E : Graph->edges()) {
      const auto EdgeIdx = Graph->getEdgeIndex(E);
      Edges[EdgeIdx] = Graph->getNodeIndex(*E.getDest());
      EdgeValues[EdgeIdx] = E.getValue();
    }

    OptimizeCut(Nodes.get(), NodeCount, Edges.get(), EdgeValues.get(),
                EdgeCuts.get(), EdgeCount);

    // A non-zero entry means the plugin wants that edge cut.
    for (int I = 0; I < EdgeCount; ++I) {
      if (EdgeCuts[I])
        CutEdges.set(I);
    }
    LLVM_DEBUG(dbgs() << "Cutting edges... Done\n");
    LLVM_DEBUG(dbgs() << "Cut " << CutEdges.count() << " edges\n");

    LLVM_DEBUG(dbgs() << "Inserting LFENCEs...\n");
    FencesInserted += insertFences(MF, *Graph, CutEdges);
    LLVM_DEBUG(dbgs() << "Inserting LFENCEs... Done\n");
    LLVM_DEBUG(dbgs() << "Inserted " << FencesInserted << " fences\n");

    // Drop the cut edges (and no nodes) before the next round.
    Graph = GraphBuilder::trim(*Graph, NodeSet{*Graph}, CutEdges);
  }
}

/// Hardens \p MF with the built-in greedy heuristic.
///
/// For every gadget edge, the heuristic cuts either all CFG edges leaving the
/// gadget source or all CFG edges entering the gadget sink, whichever has the
/// smaller summed edge value among the edges that are not cut yet (the egress
/// side wins ties). Fences are then inserted for the collected cuts.
///
/// \returns the number of fences inserted.
int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithHeuristic(
    MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const {
  // Without fences in `MF` nothing can have been mitigated yet, so trimming
  // is only worthwhile when the graph contains at least one fence.
  if (Graph->NumFences > 0) {
    LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n");
    Graph = trimMitigatedEdges(std::move(Graph));
    LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n");
  }

  if (Graph->NumGadgets == 0)
    return 0;

  LLVM_DEBUG(dbgs() << "Cutting edges...\n");
  EdgeSet CutEdges{*Graph};

  // Index the incoming CFG edges of every node.
  DenseMap<const Node *, SmallVector<const Edge *, 2>> IngressEdgeMap;
  for (const Edge &Candidate : Graph->edges()) {
    if (!MachineGadgetGraph::isCFGEdge(Candidate))
      continue;
    IngressEdgeMap[Candidate.getDest()].push_back(&Candidate);
  }

  // Summed value of the edges in `List` that have not been cut so far.
  auto UncutCost = [&CutEdges](const SmallVector<const Edge *, 2> &List) {
    int Cost = 0;
    for (const Edge *Member : List) {
      if (!CutEdges.contains(*Member))
        Cost += Member->getValue();
    }
    return Cost;
  };

  // Visit each gadget edge and pick the cheaper of the two possible cuts:
  // (a) all egress CFG edges of the gadget source, or
  // (b) all ingress CFG edges of the gadget sink.
  for (const Node &Source : Graph->nodes()) {
    for (const Edge &Gadget : Source.edges()) {
      if (!MachineGadgetGraph::isGadgetEdge(Gadget))
        continue;

      SmallVector<const Edge *, 2> &IngressEdges =
          IngressEdgeMap[Gadget.getDest()];
      SmallVector<const Edge *, 2> EgressEdges;
      for (const Edge &Outgoing : Source.edges()) {
        if (MachineGadgetGraph::isCFGEdge(Outgoing))
          EgressEdges.push_back(&Outgoing);
      }

      const int EgressCutCost = UncutCost(EgressEdges);
      const int IngressCutCost = UncutCost(IngressEdges);

      const SmallVector<const Edge *, 2> *Chosen = &EgressEdges;
      if (IngressCutCost < EgressCutCost)
        Chosen = &IngressEdges;
      for (const Edge *ToCut : *Chosen)
        CutEdges.insert(*ToCut);
    }
  }
  LLVM_DEBUG(dbgs() << "Cutting edges... Done\n");
  LLVM_DEBUG(dbgs() << "Cut " << CutEdges.count() << " edges\n");

  LLVM_DEBUG(dbgs() << "Inserting LFENCEs...\n");
  int FencesInserted = insertFences(MF, *Graph, CutEdges);
  LLVM_DEBUG(dbgs() << "Inserting LFENCEs... Done\n");
  LLVM_DEBUG(dbgs() << "Inserted " << FencesInserted << " fences\n");

  return FencesInserted;
}

/// Turn the cuts recorded in \p CutEdges into LFENCE instructions.
///
/// Every node of \p G is visited, and for each of its edges that is contained
/// in \p CutEdges an insertion point is derived from the node's instruction:
///   - argument sentinel node: the beginning of the function's entry block;
///   - branch: directly in front of the branch (all egress CFG edges of that
///     node are additionally recorded in \p CutEdges);
///   - anything else: directly behind the instruction.
/// The LFENCE is only emitted when neither the instruction at the insertion
/// point nor the one preceding it is already a fence.
///
/// \param MF       function receiving the fences.
/// \param G        gadget graph whose nodes and edges are walked.
/// \param CutEdges in/out set of cut edges; may grow while branches are
///                 processed.
/// \returns the number of LFENCE instructions that were inserted.
int X86LoadValueInjectionLoadHardeningPass::insertFences(
    MachineFunction &MF, MachineGadgetGraph &G,
    EdgeSet &CutEdges /* in, out */) const {
  int NumInserted = 0;
  for (const Node &GraphNode : G.nodes()) {
    for (const Edge &Candidate : GraphNode.edges()) {
      if (!CutEdges.contains(Candidate))
        continue;

      MachineInstr *Instr = GraphNode.getValue();
      MachineInstr *Before = nullptr;    // instruction ahead of the LFENCE
      MachineBasicBlock *Block = nullptr; // block that receives the LFENCE
      MachineBasicBlock::iterator Where;  // position of the LFENCE in Block

      if (Instr == MachineGadgetGraph::ArgNodeSentinel) {
        // Function arguments: fence at the very top of the entry block.
        Block = &MF.front();
        Where = Block->begin();
        Before = nullptr;
      } else if (Instr->isBranch()) {
        // Branches: the fence goes in front of the branch itself.
        Block = Instr->getParent();
        Where = Instr;
        Before = Instr->getPrevNode();
        // The fence in front of the branch keeps gadgets from crossing it, so
        // every egress CFG edge of this node is recorded as cut as well.
        for (const Edge &Outgoing : GraphNode.edges()) {
          if (MachineGadgetGraph::isCFGEdge(Outgoing))
            CutEdges.insert(Outgoing);
        }
      } else {
        // Everything else: the fence goes right behind the instruction.
        Block = Instr->getParent();
        if (MachineInstr *Next = Instr->getNextNode())
          Where = Next;
        else
          Where = Block->end();

        if (Where == Block->end())
          Before = Block->empty() ? nullptr : &Block->back();
        else
          Before = Where->getPrevNode();
      }

      // Skip the insertion if it would put two fences back to back.
      if (Where != Block->end() && isFence(&*Where))
        continue;
      if (Before && isFence(Before))
        continue;

      BuildMI(*Block, Where, DebugLoc(), TII->get(X86::LFENCE));
      ++NumInserted;
    }
  }
  return NumInserted;
}

bool X86LoadValueInjectionLoadHardeningPass::instrUsesRegToAccessMemory(
    const MachineInstr &MI, unsigned Reg) const {
  if (!MI.mayLoadOrStore() || MI.getOpcode() == X86::MFENCE ||
      MI.getOpcode() == X86::SFENCE || MI.getOpcode() == X86::LFENCE)
    return false;

  // FIXME: This does not handle pseudo loading instruction like TCRETURN*
  const MCInstrDesc &Desc = MI.getDesc();
  int MemRefBeginIdx = X86II::getMemoryOperandNo(Desc.TSFlags);
  if (MemRefBeginIdx < 0) {
    LLVM_DEBUG(dbgs() << "Warning: unable to obtain memory operand for loading "
                         "instruction:\n";
               MI.print(dbgs()); dbgs() << '\n';);
    return false;
  }
  MemRefBeginIdx += X86II::getOperandBias(Desc);

  const MachineOperand &BaseMO =
      MI.getOperand(MemRefBeginIdx + X86::AddrBaseReg);
  const MachineOperand &IndexMO =
      MI.getOperand(MemRefBeginIdx + X86::AddrIndexReg);
  return (BaseMO.isReg() && BaseMO.getReg() != X86::NoRegister &&
          TRI->regsOverlap(BaseMO.getReg(), Reg)) ||
         (IndexMO.isReg() && IndexMO.getReg() != X86::NoRegister &&
          TRI->regsOverlap(IndexMO.getReg(), Reg));
}

bool X86LoadValueInjectionLoadHardeningPass::instrUsesRegToBranch(
    const MachineInstr &MI, unsigned Reg) const {
  if (!MI.isConditionalBranch())
    return false;
  for (const MachineOperand &Use : MI.uses())
    if (Use.isReg() && Use.getReg() == Reg)
      return true;
  return false;
}

// Pass registration: associates X86LoadValueInjectionLoadHardeningPass with
// PASS_KEY and the description "X86 LVI load hardening", and lists the passes
// it is declared to depend on (MachineLoopInfo, MachineDominatorTree and
// MachineDominanceFrontier). The BEGIN and END invocations must carry the same
// arguments.
INITIALIZE_PASS_BEGIN(X86LoadValueInjectionLoadHardeningPass, PASS_KEY,
                      "X86 LVI load hardening", false, false)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_DEPENDENCY(MachineDominanceFrontier)
INITIALIZE_PASS_END(X86LoadValueInjectionLoadHardeningPass, PASS_KEY,
                    "X86 LVI load hardening", false, false)

/// Factory for the X86 LVI load-hardening pass.
///
/// \returns a newly allocated X86LoadValueInjectionLoadHardeningPass, exposed
/// through its FunctionPass interface. The object is created with `new` and is
/// not released here, so ownership passes to the caller.
FunctionPass *llvm::createX86LoadValueInjectionLoadHardeningPass() {
  return new X86LoadValueInjectionLoadHardeningPass();
}