path: root/contrib/libs/llvm12/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
author: shadchin <shadchin@yandex-team.ru> 2022-02-10 16:44:30 +0300
committer: Daniil Cherednik <dcherednik@yandex-team.ru> 2022-02-10 16:44:30 +0300
commit: 2598ef1d0aee359b4b6d5fdd1758916d5907d04f (patch)
tree: 012bb94d777798f1f56ac1cec429509766d05181 /contrib/libs/llvm12/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
parent: 6751af0b0c1b952fede40b19b71da8025b5d8bcf (diff)
download: ydb-2598ef1d0aee359b4b6d5fdd1758916d5907d04f.tar.gz
Restoring authorship annotation for <shadchin@yandex-team.ru>. Commit 1 of 2.
Diffstat (limited to 'contrib/libs/llvm12/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp')
-rw-r--r--  contrib/libs/llvm12/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp  374
1 file changed, 187 insertions, 187 deletions
diff --git a/contrib/libs/llvm12/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp b/contrib/libs/llvm12/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
index fdd04cb77f..bf3190ce93 100644
--- a/contrib/libs/llvm12/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
+++ b/contrib/libs/llvm12/lib/Target/AArch64/GISel/AArch64PostLegalizerCombiner.cpp
@@ -1,22 +1,22 @@
-//=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
+//=== AArch64PostLegalizerCombiner.cpp --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
-///
-/// \file
-/// Post-legalization combines on generic MachineInstrs.
-///
-/// The combines here must preserve instruction legality.
-///
-/// Lowering combines (e.g. pseudo matching) should be handled by
-/// AArch64PostLegalizerLowering.
-///
-/// Combines which don't rely on instruction legality should go in the
-/// AArch64PreLegalizerCombiner.
-///
+///
+/// \file
+/// Post-legalization combines on generic MachineInstrs.
+///
+/// The combines here must preserve instruction legality.
+///
+/// Lowering combines (e.g. pseudo matching) should be handled by
+/// AArch64PostLegalizerLowering.
+///
+/// Combines which don't rely on instruction legality should go in the
+/// AArch64PreLegalizerCombiner.
+///
//===----------------------------------------------------------------------===//
#include "AArch64TargetMachine.h"
@@ -24,12 +24,12 @@
#include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
#include "llvm/CodeGen/GlobalISel/CombinerInfo.h"
#include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
-#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
-#include "llvm/CodeGen/GlobalISel/Utils.h"
+#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
+#include "llvm/CodeGen/GlobalISel/Utils.h"
#include "llvm/CodeGen/MachineDominators.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
-#include "llvm/CodeGen/MachineRegisterInfo.h"
-#include "llvm/CodeGen/TargetOpcodes.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/TargetOpcodes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/Support/Debug.h"
@@ -37,202 +37,202 @@
using namespace llvm;
-/// This combine tries to do what performExtractVectorEltCombine does in SDAG.
-/// Rewrite for pairwise fadd pattern
-/// (s32 (g_extract_vector_elt
-/// (g_fadd (vXs32 Other)
-/// (g_vector_shuffle (vXs32 Other) undef <1,X,...> )) 0))
-/// ->
-/// (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
-/// (g_extract_vector_elt (vXs32 Other) 1))
-bool matchExtractVecEltPairwiseAdd(
- MachineInstr &MI, MachineRegisterInfo &MRI,
- std::tuple<unsigned, LLT, Register> &MatchInfo) {
- Register Src1 = MI.getOperand(1).getReg();
- Register Src2 = MI.getOperand(2).getReg();
- LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
-
- auto Cst = getConstantVRegValWithLookThrough(Src2, MRI);
- if (!Cst || Cst->Value != 0)
+/// This combine tries to do what performExtractVectorEltCombine does in SDAG.
+/// Rewrite for pairwise fadd pattern
+/// (s32 (g_extract_vector_elt
+/// (g_fadd (vXs32 Other)
+/// (g_vector_shuffle (vXs32 Other) undef <1,X,...> )) 0))
+/// ->
+/// (s32 (g_fadd (g_extract_vector_elt (vXs32 Other) 0)
+/// (g_extract_vector_elt (vXs32 Other) 1))
+bool matchExtractVecEltPairwiseAdd(
+ MachineInstr &MI, MachineRegisterInfo &MRI,
+ std::tuple<unsigned, LLT, Register> &MatchInfo) {
+ Register Src1 = MI.getOperand(1).getReg();
+ Register Src2 = MI.getOperand(2).getReg();
+ LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
+
+ auto Cst = getConstantVRegValWithLookThrough(Src2, MRI);
+ if (!Cst || Cst->Value != 0)
return false;
- // SDAG also checks for FullFP16, but this looks to be beneficial anyway.
+ // SDAG also checks for FullFP16, but this looks to be beneficial anyway.
- // Now check for an fadd operation. TODO: expand this for integer add?
- auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
- if (!FAddMI)
+ // Now check for an fadd operation. TODO: expand this for integer add?
+ auto *FAddMI = getOpcodeDef(TargetOpcode::G_FADD, Src1, MRI);
+ if (!FAddMI)
return false;
- // If we add support for integer add, must restrict these types to just s64.
- unsigned DstSize = DstTy.getSizeInBits();
- if (DstSize != 16 && DstSize != 32 && DstSize != 64)
+ // If we add support for integer add, must restrict these types to just s64.
+ unsigned DstSize = DstTy.getSizeInBits();
+ if (DstSize != 16 && DstSize != 32 && DstSize != 64)
return false;
- Register Src1Op1 = FAddMI->getOperand(1).getReg();
- Register Src1Op2 = FAddMI->getOperand(2).getReg();
- MachineInstr *Shuffle =
- getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
- MachineInstr *Other = MRI.getVRegDef(Src1Op1);
- if (!Shuffle) {
- Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
- Other = MRI.getVRegDef(Src1Op2);
+ Register Src1Op1 = FAddMI->getOperand(1).getReg();
+ Register Src1Op2 = FAddMI->getOperand(2).getReg();
+ MachineInstr *Shuffle =
+ getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op2, MRI);
+ MachineInstr *Other = MRI.getVRegDef(Src1Op1);
+ if (!Shuffle) {
+ Shuffle = getOpcodeDef(TargetOpcode::G_SHUFFLE_VECTOR, Src1Op1, MRI);
+ Other = MRI.getVRegDef(Src1Op2);
}
- // We're looking for a shuffle that moves the second element to index 0.
- if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
- Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
- std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
- std::get<1>(MatchInfo) = DstTy;
- std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
+ // We're looking for a shuffle that moves the second element to index 0.
+ if (Shuffle && Shuffle->getOperand(3).getShuffleMask()[0] == 1 &&
+ Other == MRI.getVRegDef(Shuffle->getOperand(1).getReg())) {
+ std::get<0>(MatchInfo) = TargetOpcode::G_FADD;
+ std::get<1>(MatchInfo) = DstTy;
+ std::get<2>(MatchInfo) = Other->getOperand(0).getReg();
return true;
}
return false;
}
-bool applyExtractVecEltPairwiseAdd(
- MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
- std::tuple<unsigned, LLT, Register> &MatchInfo) {
- unsigned Opc = std::get<0>(MatchInfo);
- assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
- // We want to generate two extracts of elements 0 and 1, and add them.
- LLT Ty = std::get<1>(MatchInfo);
- Register Src = std::get<2>(MatchInfo);
- LLT s64 = LLT::scalar(64);
- B.setInstrAndDebugLoc(MI);
- auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
- auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
- B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
- MI.eraseFromParent();
+bool applyExtractVecEltPairwiseAdd(
+ MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
+ std::tuple<unsigned, LLT, Register> &MatchInfo) {
+ unsigned Opc = std::get<0>(MatchInfo);
+ assert(Opc == TargetOpcode::G_FADD && "Unexpected opcode!");
+ // We want to generate two extracts of elements 0 and 1, and add them.
+ LLT Ty = std::get<1>(MatchInfo);
+ Register Src = std::get<2>(MatchInfo);
+ LLT s64 = LLT::scalar(64);
+ B.setInstrAndDebugLoc(MI);
+ auto Elt0 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 0));
+ auto Elt1 = B.buildExtractVectorElement(Ty, Src, B.buildConstant(s64, 1));
+ B.buildInstr(Opc, {MI.getOperand(0).getReg()}, {Elt0, Elt1});
+ MI.eraseFromParent();
return true;
}
-static bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
- // TODO: check if extended build vector as well.
- unsigned Opc = MRI.getVRegDef(R)->getOpcode();
- return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
+static bool isSignExtended(Register R, MachineRegisterInfo &MRI) {
+ // TODO: check if extended build vector as well.
+ unsigned Opc = MRI.getVRegDef(R)->getOpcode();
+ return Opc == TargetOpcode::G_SEXT || Opc == TargetOpcode::G_SEXT_INREG;
}
-static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
- // TODO: check if extended build vector as well.
- return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
+static bool isZeroExtended(Register R, MachineRegisterInfo &MRI) {
+ // TODO: check if extended build vector as well.
+ return MRI.getVRegDef(R)->getOpcode() == TargetOpcode::G_ZEXT;
}
-bool matchAArch64MulConstCombine(
- MachineInstr &MI, MachineRegisterInfo &MRI,
- std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
- assert(MI.getOpcode() == TargetOpcode::G_MUL);
- Register LHS = MI.getOperand(1).getReg();
- Register RHS = MI.getOperand(2).getReg();
- Register Dst = MI.getOperand(0).getReg();
- const LLT Ty = MRI.getType(LHS);
-
- // The below optimizations require a constant RHS.
- auto Const = getConstantVRegValWithLookThrough(RHS, MRI);
- if (!Const)
+bool matchAArch64MulConstCombine(
+ MachineInstr &MI, MachineRegisterInfo &MRI,
+ std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
+ assert(MI.getOpcode() == TargetOpcode::G_MUL);
+ Register LHS = MI.getOperand(1).getReg();
+ Register RHS = MI.getOperand(2).getReg();
+ Register Dst = MI.getOperand(0).getReg();
+ const LLT Ty = MRI.getType(LHS);
+
+ // The below optimizations require a constant RHS.
+ auto Const = getConstantVRegValWithLookThrough(RHS, MRI);
+ if (!Const)
return false;
- const APInt ConstValue = Const->Value.sextOrSelf(Ty.getSizeInBits());
- // The following code is ported from AArch64ISelLowering.
- // Multiplication of a power of two plus/minus one can be done more
- // cheaply as a shift+add/sub. For now, this is true unilaterally. If
- // future CPUs have a cheaper MADD instruction, this may need to be
- // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
- // 64-bit is 5 cycles, so this is always a win.
- // More aggressively, some multiplications N0 * C can be lowered to
- // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
- // e.g. 6=3*2=(2+1)*2.
- // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
- // which equals to (1+2)*16-(1+2).
- // TrailingZeroes is used to test if the mul can be lowered to
- // shift+add+shift.
- unsigned TrailingZeroes = ConstValue.countTrailingZeros();
- if (TrailingZeroes) {
- // Conservatively do not lower to shift+add+shift if the mul might be
- // folded into smul or umul.
- if (MRI.hasOneNonDBGUse(LHS) &&
- (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
- return false;
- // Conservatively do not lower to shift+add+shift if the mul might be
- // folded into madd or msub.
- if (MRI.hasOneNonDBGUse(Dst)) {
- MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
- if (UseMI.getOpcode() == TargetOpcode::G_ADD ||
- UseMI.getOpcode() == TargetOpcode::G_SUB)
- return false;
- }
- }
- // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
- // and shift+add+shift.
- APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);
-
- unsigned ShiftAmt, AddSubOpc;
- // Is the shifted value the LHS operand of the add/sub?
- bool ShiftValUseIsLHS = true;
- // Do we need to negate the result?
- bool NegateResult = false;
-
- if (ConstValue.isNonNegative()) {
- // (mul x, 2^N + 1) => (add (shl x, N), x)
- // (mul x, 2^N - 1) => (sub (shl x, N), x)
- // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
- APInt SCVMinus1 = ShiftedConstValue - 1;
- APInt CVPlus1 = ConstValue + 1;
- if (SCVMinus1.isPowerOf2()) {
- ShiftAmt = SCVMinus1.logBase2();
- AddSubOpc = TargetOpcode::G_ADD;
- } else if (CVPlus1.isPowerOf2()) {
- ShiftAmt = CVPlus1.logBase2();
- AddSubOpc = TargetOpcode::G_SUB;
- } else
- return false;
- } else {
- // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
- // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
- APInt CVNegPlus1 = -ConstValue + 1;
- APInt CVNegMinus1 = -ConstValue - 1;
- if (CVNegPlus1.isPowerOf2()) {
- ShiftAmt = CVNegPlus1.logBase2();
- AddSubOpc = TargetOpcode::G_SUB;
- ShiftValUseIsLHS = false;
- } else if (CVNegMinus1.isPowerOf2()) {
- ShiftAmt = CVNegMinus1.logBase2();
- AddSubOpc = TargetOpcode::G_ADD;
- NegateResult = true;
- } else
- return false;
- }
-
- if (NegateResult && TrailingZeroes)
+ const APInt ConstValue = Const->Value.sextOrSelf(Ty.getSizeInBits());
+ // The following code is ported from AArch64ISelLowering.
+ // Multiplication of a power of two plus/minus one can be done more
+ // cheaply as a shift+add/sub. For now, this is true unilaterally. If
+ // future CPUs have a cheaper MADD instruction, this may need to be
+ // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
+ // 64-bit is 5 cycles, so this is always a win.
+ // More aggressively, some multiplications N0 * C can be lowered to
+ // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
+ // e.g. 6=3*2=(2+1)*2.
+ // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
+ // which equals to (1+2)*16-(1+2).
+ // TrailingZeroes is used to test if the mul can be lowered to
+ // shift+add+shift.
+ unsigned TrailingZeroes = ConstValue.countTrailingZeros();
+ if (TrailingZeroes) {
+ // Conservatively do not lower to shift+add+shift if the mul might be
+ // folded into smul or umul.
+ if (MRI.hasOneNonDBGUse(LHS) &&
+ (isSignExtended(LHS, MRI) || isZeroExtended(LHS, MRI)))
+ return false;
+ // Conservatively do not lower to shift+add+shift if the mul might be
+ // folded into madd or msub.
+ if (MRI.hasOneNonDBGUse(Dst)) {
+ MachineInstr &UseMI = *MRI.use_instr_begin(Dst);
+ if (UseMI.getOpcode() == TargetOpcode::G_ADD ||
+ UseMI.getOpcode() == TargetOpcode::G_SUB)
+ return false;
+ }
+ }
+ // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
+ // and shift+add+shift.
+ APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);
+
+ unsigned ShiftAmt, AddSubOpc;
+ // Is the shifted value the LHS operand of the add/sub?
+ bool ShiftValUseIsLHS = true;
+ // Do we need to negate the result?
+ bool NegateResult = false;
+
+ if (ConstValue.isNonNegative()) {
+ // (mul x, 2^N + 1) => (add (shl x, N), x)
+ // (mul x, 2^N - 1) => (sub (shl x, N), x)
+ // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
+ APInt SCVMinus1 = ShiftedConstValue - 1;
+ APInt CVPlus1 = ConstValue + 1;
+ if (SCVMinus1.isPowerOf2()) {
+ ShiftAmt = SCVMinus1.logBase2();
+ AddSubOpc = TargetOpcode::G_ADD;
+ } else if (CVPlus1.isPowerOf2()) {
+ ShiftAmt = CVPlus1.logBase2();
+ AddSubOpc = TargetOpcode::G_SUB;
+ } else
+ return false;
+ } else {
+ // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
+ // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
+ APInt CVNegPlus1 = -ConstValue + 1;
+ APInt CVNegMinus1 = -ConstValue - 1;
+ if (CVNegPlus1.isPowerOf2()) {
+ ShiftAmt = CVNegPlus1.logBase2();
+ AddSubOpc = TargetOpcode::G_SUB;
+ ShiftValUseIsLHS = false;
+ } else if (CVNegMinus1.isPowerOf2()) {
+ ShiftAmt = CVNegMinus1.logBase2();
+ AddSubOpc = TargetOpcode::G_ADD;
+ NegateResult = true;
+ } else
+ return false;
+ }
+
+ if (NegateResult && TrailingZeroes)
return false;
- ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
- auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
- auto ShiftedVal = B.buildShl(Ty, LHS, Shift);
-
- Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
- Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
- auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
- assert(!(NegateResult && TrailingZeroes) &&
- "NegateResult and TrailingZeroes cannot both be true for now.");
- // Negate the result.
- if (NegateResult) {
- B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
- return;
- }
- // Shift the result.
- if (TrailingZeroes) {
- B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
- return;
- }
- B.buildCopy(DstReg, Res.getReg(0));
- };
+ ApplyFn = [=](MachineIRBuilder &B, Register DstReg) {
+ auto Shift = B.buildConstant(LLT::scalar(64), ShiftAmt);
+ auto ShiftedVal = B.buildShl(Ty, LHS, Shift);
+
+ Register AddSubLHS = ShiftValUseIsLHS ? ShiftedVal.getReg(0) : LHS;
+ Register AddSubRHS = ShiftValUseIsLHS ? LHS : ShiftedVal.getReg(0);
+ auto Res = B.buildInstr(AddSubOpc, {Ty}, {AddSubLHS, AddSubRHS});
+ assert(!(NegateResult && TrailingZeroes) &&
+ "NegateResult and TrailingZeroes cannot both be true for now.");
+ // Negate the result.
+ if (NegateResult) {
+ B.buildSub(DstReg, B.buildConstant(Ty, 0), Res);
+ return;
+ }
+ // Shift the result.
+ if (TrailingZeroes) {
+ B.buildShl(DstReg, Res, B.buildConstant(LLT::scalar(64), TrailingZeroes));
+ return;
+ }
+ B.buildCopy(DstReg, Res.getReg(0));
+ };
return true;
}
-bool applyAArch64MulConstCombine(
- MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
- std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
- B.setInstrAndDebugLoc(MI);
- ApplyFn(B, MI.getOperand(0).getReg());
+bool applyAArch64MulConstCombine(
+ MachineInstr &MI, MachineRegisterInfo &MRI, MachineIRBuilder &B,
+ std::function<void(MachineIRBuilder &B, Register DstReg)> &ApplyFn) {
+ B.setInstrAndDebugLoc(MI);
+ ApplyFn(B, MI.getOperand(0).getReg());
MI.eraseFromParent();
return true;
}
@@ -348,7 +348,7 @@ INITIALIZE_PASS_END(AArch64PostLegalizerCombiner, DEBUG_TYPE,
false)
namespace llvm {
-FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
+FunctionPass *createAArch64PostLegalizerCombiner(bool IsOptNone) {
return new AArch64PostLegalizerCombiner(IsOptNone);
}
} // end namespace llvm
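
The pattern matched by matchExtractVecEltPairwiseAdd in the diff above can be sanity-checked with plain scalar arithmetic. The following minimal sketch is not part of this commit and uses no LLVM API; it only assumes, as the matcher does, that the shuffle mask places element 1 into lane 0.

// Scalar re-statement of the pattern matched above: taking lane 0 of
// (G_FADD V, (G_SHUFFLE_VECTOR V, undef, <1, ...>)) is the same as V[0] + V[1].
#include <array>
#include <cassert>

int main() {
  std::array<float, 4> V{1.5f, 2.25f, 3.0f, 4.0f};
  // Any shuffle mask whose lane 0 index is 1 works; the other lanes are don't-cares.
  std::array<float, 4> Shuf{V[1], V[3], V[2], V[0]};
  float Lane0OfFAdd = V[0] + Shuf[0]; // lane 0 of the vector fadd
  assert(Lane0OfFAdd == V[0] + V[1]); // what the combine rewrites it into
  return 0;
}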
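
The comment block ported from AArch64ISelLowering inside matchAArch64MulConstCombine explains when x * C can be rewritten as one shift plus one add/sub, optionally followed by a final shift (e.g. 6 = (2+1)*2). The standalone sketch below re-derives that decomposition and checks it against plain multiplication; it is illustrative only, not part of this commit, and the names MulDecomp, decomposeMulConst and evalDecomp are invented here rather than taken from LLVM.

// Standalone illustration of the shift+add/sub decomposition for x * C that
// the ported comment describes. Names here are invented for this sketch.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <optional>

struct MulDecomp {
  unsigned ShiftAmt = 0;       // N in (x << N)
  bool IsAdd = true;           // add for 2^N + 1, sub for 2^N - 1
  bool ShiftValIsLHS = true;   // is (x << N) the left operand of the add/sub?
  bool NegateResult = false;   // negate at the end (some negative constants)
  unsigned TrailingZeroes = 0; // final shift, for C = (2^N + 1) * 2^M
};

// Mirrors the checks in matchAArch64MulConstCombine: returns how x * C can be
// rewritten with shifts and a single add/sub, or nullopt if C does not fit.
static std::optional<MulDecomp> decomposeMulConst(int64_t C) {
  if (C == 0)
    return std::nullopt;
  auto isPow2 = [](int64_t V) { return V > 0 && (V & (V - 1)) == 0; };
  auto ilog2 = [](int64_t V) {
    unsigned L = 0;
    while (V > 1) { V >>= 1; ++L; }
    return L;
  };
  MulDecomp D;
  for (uint64_t U = static_cast<uint64_t>(C); (U & 1) == 0; U >>= 1)
    ++D.TrailingZeroes;
  int64_t Shifted = C >> D.TrailingZeroes; // ConstValue.ashr(TrailingZeroes)
  if (C > 0) {
    if (isPow2(Shifted - 1)) {        // C = (2^N + 1) * 2^M
      D.ShiftAmt = ilog2(Shifted - 1);
      D.IsAdd = true;
    } else if (isPow2(C + 1)) {       // C = 2^N - 1
      D.ShiftAmt = ilog2(C + 1);
      D.IsAdd = false;
    } else {
      return std::nullopt;
    }
  } else {
    if (isPow2(-C + 1)) {             // C = -(2^N - 1): x - (x << N)
      D.ShiftAmt = ilog2(-C + 1);
      D.IsAdd = false;
      D.ShiftValIsLHS = false;
    } else if (isPow2(-C - 1)) {      // C = -(2^N + 1): -((x << N) + x)
      D.ShiftAmt = ilog2(-C - 1);
      D.IsAdd = true;
      D.NegateResult = true;
    } else {
      return std::nullopt;
    }
  }
  if (D.NegateResult && D.TrailingZeroes) // the combine also gives up here
    return std::nullopt;
  return D;
}

// Evaluate the decomposition on a concrete value to confirm it equals x * C.
static int64_t evalDecomp(const MulDecomp &D, int64_t X) {
  int64_t ShiftedVal = X << D.ShiftAmt;
  int64_t L = D.ShiftValIsLHS ? ShiftedVal : X;
  int64_t R = D.ShiftValIsLHS ? X : ShiftedVal;
  int64_t Res = D.IsAdd ? L + R : L - R;
  return D.NegateResult ? -Res : (Res << D.TrailingZeroes);
}

int main() {
  // e.g. 6 = (2^1 + 1) * 2^1, so x * 6 == ((x << 1) + x) << 1.
  for (int64_t C : {3, 5, 6, 7, 12, 24, -3, -5, -7}) {
    auto D = decomposeMulConst(C);
    assert(D && evalDecomp(*D, 41) == 41 * C);
    std::printf("x * %lld decomposes fine\n", static_cast<long long>(C));
  }
  return 0;
}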