aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/libs/llvm16/lib/Target/WebAssembly/WebAssemblyFixBrTableDefaults.cpp
blob: fa5b4a508fa572bb9504bd75c1a253bf2901aa43 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
//=- WebAssemblyFixBrTableDefaults.cpp - Fix br_table default branch targets -//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file This file implements a pass that eliminates redundant range checks
/// guarding br_table instructions. Since jump tables on most targets cannot
/// handle out of range indices, LLVM emits these checks before most jump
/// tables. But br_table takes a default branch target as an argument, so it
/// does not need the range checks.
///
//===----------------------------------------------------------------------===//

#include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
#include "WebAssembly.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/Pass.h"

using namespace llvm;

#define DEBUG_TYPE "wasm-fix-br-table-defaults"

namespace {

/// Machine-function pass that rewrites br_table instructions: it repairs the
/// index operand and installs the real default target (see the helper
/// functions defined later in this file).
class WebAssemblyFixBrTableDefaults final : public MachineFunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  /// Registers this pass with its base class using the static `ID`.
  WebAssemblyFixBrTableDefaults() : MachineFunctionPass(ID) {}

private:
  /// Human-readable name of the pass.
  StringRef getPassName() const override {
    return "WebAssembly Fix br_table Defaults";
  }

  /// Pass entry point; defined out of line.
  bool runOnMachineFunction(MachineFunction &MF) override;
};

// Storage for the pass identifier declared in the class.
char WebAssemblyFixBrTableDefaults::ID = 0;

// The target-independent SelectionDAG assumes it may use PointerTy as the
// index type of a "switch" (see e.g. SelectionDAGBuilder::visitJumpTableHeader),
// whereas Wasm so far only has a 32-bit br_table. The tablegen defs therefore
// carry a 64-bit br_table pseudo, which does get selected on wasm64 and drags
// incorrect truncates/extensions along with it. This helper rewrites such a
// br_table so that it consumes a 32-bit index again.
//
// `MI` is the br_table instruction, `MBB` the block a wrap instruction is
// inserted into (directly before `MI`), and `MF` the enclosing function.
void fixBrTableIndex(MachineInstr &MI, MachineBasicBlock *MBB,
                     MachineFunction &MF) {
  const auto &Subtarget = MF.getSubtarget<WebAssemblySubtarget>();
  // Nothing to repair unless we are compiling for wasm64.
  if (!Subtarget.hasAddr64())
    return;

  assert(MI.getDesc().getOpcode() == WebAssembly::BR_TABLE_I64 &&
         "64-bit br_table pseudo instruction expected");

  MachineRegisterInfo &MRI = MF.getRegInfo();
  const auto *InstrInfo = Subtarget.getInstrInfo();
  MachineOperand &IndexOp = MI.getOperand(0);
  const Register IndexReg = IndexOp.getReg();

  // Look at whatever defines the index; an extension, if present, sits in the
  // previous BB before the branch.
  MachineInstr *IndexDef = MRI.getVRegDef(IndexReg);
  if (IndexDef->getOpcode() != WebAssembly::I64_EXTEND_U_I32) {
    // A genuine 64-bit value is coming in: wrap it down to 32 bits right
    // before the br_table and feed the wrapped value to it.
    Register Wrapped = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
    BuildMI(*MBB, MI.getIterator(), MI.getDebugLoc(),
            InstrInfo->get(WebAssembly::I32_WRAP_I64), Wrapped)
        .addReg(IndexReg);
    IndexOp.setReg(Wrapped);
  } else {
    // The index is a 32-bit value that was needlessly widened; use the
    // extension's source register directly.
    const Register ExtDefReg = IndexDef->getOperand(0).getReg();
    assert(IndexReg == ExtDefReg);
    IndexOp.setReg(IndexDef->getOperand(1).getReg());
    // Drop the extension once nothing (other than debug uses) reads it.
    if (MRI.use_nodbg_empty(ExtDefReg))
      IndexDef->eraseFromParent();
  }

  // Either way the operand is 32 bits wide now, so switch to the matching
  // instruction description.
  MI.setDesc(InstrInfo->get(WebAssembly::BR_TABLE_I32));
}

// Installs the real default target on the br_table `MI` (which currently
// carries a dummy default argument) and folds the br_table's block `MBB` into
// the guard block that precedes it, dropping the redundant range check.
//
// Returns the block the br_table ends up in, so the caller can exclude it from
// further processing, or nullptr when the br_table cannot be optimized (in
// which case nothing has been changed by this function).
MachineBasicBlock *fixBrTableDefault(MachineInstr &MI, MachineBasicBlock *MBB,
                                     MachineFunction &MF) {
  // The sole predecessor is the guard ("header") block holding the range check.
  assert(MBB->pred_size() == 1 && "Expected a single guard predecessor");
  MachineBasicBlock *Guard = *MBB->pred_begin();

  // Ask the target to describe how the guard block terminates. When there is
  // no conditional jump to a default block, the default is unreachable and the
  // existing dummy target can simply stay.
  const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
  MachineBasicBlock *Taken = nullptr;
  MachineBasicBlock *NotTaken = nullptr;
  SmallVector<MachineOperand, 2> Cond;
  bool Analyzed = !TII.analyzeBranch(*Guard, Taken, NotTaken, Cond);
  assert(Analyzed && "Could not analyze jump header branches");
  (void)Analyzed;

  // Possible results, with '_' meaning nullptr, 'J' the jump table block
  // (i.e. MBB) and 'D' the default block:
  //
  // Taken | NotTaken | Meaning
  //   _   |    _     | No default block, guard falls through to jump table
  //   J   |    _     | No default block, guard jumps to the jump table
  //   D   |    _     | Guard jumps to the default, falls through to jump table
  //   D   |    J     | Guard jumps to the default and also to the jump table
  const bool HasDefault = Taken != nullptr && Taken != MBB;
  if (HasDefault) {
    assert((NotTaken == nullptr || NotTaken == MBB) &&
           "Expected jump or fallthrough to br_table block");
    assert(Cond.size() == 2 && Cond[1].isReg() && "Unexpected condition info");

    // A range check on an i64 value must stay: the i64 index gets truncated to
    // i32, so values above 2^32 would alias small ones. Other odd shapes show
    // up in practice as well, so be conservative and only proceed for the
    // ordinary i32.gt_u check.
    MachineInstr *RangeCheck = MF.getRegInfo().getVRegDef(Cond[1].getReg());
    assert(RangeCheck != nullptr);
    if (RangeCheck->getOpcode() != WebAssembly::GT_U_I32)
      return nullptr;

    // Swap the dummy default (last explicit operand) for the real one.
    const unsigned DummyIdx = MI.getNumExplicitOperands() - 1;
    MI.removeOperand(DummyIdx);
    MI.addOperand(MF, MachineOperand::CreateMBB(Taken));
  }

  // Strip the guard's branches and append the jump table block's instructions.
  TII.removeBranch(*Guard, nullptr);
  Guard->splice(Guard->end(), MBB, MBB->begin(), MBB->end());

  // Rewire the CFG around the old jump table block. Successors that both
  // blocks share are removed from the guard first so the transfer below does
  // not create duplicates.
  Guard->removeSuccessor(MBB);
  for (MachineBasicBlock *Succ : MBB->successors())
    if (Guard->isSuccessor(Succ))
      Guard->removeSuccessor(Succ);
  Guard->transferSuccessorsAndUpdatePHIs(MBB);

  // The old jump table block is empty and disconnected now; delete it.
  MF.erase(MBB);

  return Guard;
}

/// Visits every basic block of `MF` and, for the first br_table found in a
/// block, repairs its index operand and tries to fold the block into its
/// guard predecessor.
///
/// Returns true if any instruction or block of `MF` was modified. This
/// includes the case where only the br_table index was rewritten (wasm64) and
/// the default-target optimization itself was declined.
bool WebAssemblyFixBrTableDefaults::runOnMachineFunction(MachineFunction &MF) {
  LLVM_DEBUG(dbgs() << "********** Fixing br_table Default Targets **********\n"
                       "********** Function: "
                    << MF.getName() << '\n');

  // `Changed` tracks any modification at all; `BlocksRemoved` tracks only the
  // CFG edits that make renumbering necessary.
  bool Changed = false;
  bool BlocksRemoved = false;

  // Work list of blocks still to be inspected.
  SmallPtrSet<MachineBasicBlock *, 16> MBBSet;
  for (auto &MBB : MF)
    MBBSet.insert(&MBB);

  while (!MBBSet.empty()) {
    MachineBasicBlock *MBB = *MBBSet.begin();
    MBBSet.erase(MBB);
    for (auto &MI : *MBB) {
      if (!WebAssembly::isBrTable(MI))
        continue;

      // fixBrTableIndex reports nothing, but whenever it rewrites the
      // instruction it also switches the opcode, so an opcode difference
      // tells us the function was modified even if the default fix-up below
      // bails out.
      const unsigned OpcodeBefore = MI.getOpcode();
      fixBrTableIndex(MI, MBB, MF);
      if (MI.getOpcode() != OpcodeBefore)
        Changed = true;

      auto *Fixed = fixBrTableDefault(MI, MBB, MF);
      if (Fixed != nullptr) {
        // The br_table now lives in `Fixed`; do not visit it again.
        MBBSet.erase(Fixed);
        Changed = true;
        BlocksRemoved = true;
      }
      // On success MBB has been erased, so the iteration must stop here;
      // either way only the first br_table of a block is handled.
      break;
    }
  }

  if (BlocksRemoved) {
    // We rewrote part of the function; recompute relevant things.
    MF.RenumberBlocks();
  }

  return Changed;
}

} // end anonymous namespace

// Registers the pass under the DEBUG_TYPE string ("wasm-fix-br-table-defaults")
// together with a one-line description.
// NOTE(review): the two trailing `false` arguments are presumably the
// "CFG only" and "is analysis" flags of INITIALIZE_PASS — confirm against
// llvm/PassSupport.h.
INITIALIZE_PASS(WebAssemblyFixBrTableDefaults, DEBUG_TYPE,
                "Removes range checks and sets br_table default targets", false,
                false)

/// Factory for the pass: heap-allocates a fresh WebAssemblyFixBrTableDefaults
/// and returns it as a FunctionPass pointer.
FunctionPass *llvm::createWebAssemblyFixBrTableDefaults() {
  auto *Pass = new WebAssemblyFixBrTableDefaults();
  return Pass;
}