aboutsummaryrefslogtreecommitdiffstats
path: root/contrib/libs/llvm16/lib/Transforms/Utils/CodeExtractor.cpp
diff options
context:
space:
mode:
authorvitalyisaev <vitalyisaev@yandex-team.com>2023-06-29 10:00:50 +0300
committervitalyisaev <vitalyisaev@yandex-team.com>2023-06-29 10:00:50 +0300
commit6ffe9e53658409f212834330e13564e4952558f6 (patch)
tree85b1e00183517648b228aafa7c8fb07f5276f419 /contrib/libs/llvm16/lib/Transforms/Utils/CodeExtractor.cpp
parent726057070f9c5a91fc10fde0d5024913d10f1ab9 (diff)
downloadydb-6ffe9e53658409f212834330e13564e4952558f6.tar.gz
YQ Connector: support managed ClickHouse
Со стороны dqrun можно обратиться к инстансу коннектора, который работает на streaming стенде, и извлечь данные из облачного CH.
Diffstat (limited to 'contrib/libs/llvm16/lib/Transforms/Utils/CodeExtractor.cpp')
-rw-r--r--contrib/libs/llvm16/lib/Transforms/Utils/CodeExtractor.cpp1894
1 files changed, 1894 insertions, 0 deletions
diff --git a/contrib/libs/llvm16/lib/Transforms/Utils/CodeExtractor.cpp b/contrib/libs/llvm16/lib/Transforms/Utils/CodeExtractor.cpp
new file mode 100644
index 0000000000..c1fe10504e
--- /dev/null
+++ b/contrib/libs/llvm16/lib/Transforms/Utils/CodeExtractor.cpp
@@ -0,0 +1,1894 @@
+//===- CodeExtractor.cpp - Pull code region into a new function -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the interface to tear out a code region, such as an
+// individual loop or a parallel section, into a new function, replacing it with
+// a call to the new function.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Utils/CodeExtractor.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DIBuilder.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugInfo.h"
+#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/IR/Verifier.h"
+#include "llvm/Support/BlockFrequency.h"
+#include "llvm/Support/BranchProbability.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include <cassert>
+#include <cstdint>
+#include <iterator>
+#include <map>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+using ProfileCount = Function::ProfileCount;
+
+#define DEBUG_TYPE "code-extractor"
+
+// Provide a command-line option to aggregate function arguments into a struct
+// for functions produced by the code extractor. This is useful when converting
+// extracted functions to pthread-based code, as only one argument (void*) can
+// be passed in to pthread_create().
+static cl::opt<bool>
+AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
+ cl::desc("Aggregate arguments to code-extracted functions"));
+
+/// Test whether a block is valid for extraction.
+static bool isBlockValidForExtraction(const BasicBlock &BB,
+ const SetVector<BasicBlock *> &Result,
+ bool AllowVarArgs, bool AllowAlloca) {
+ // taking the address of a basic block moved to another function is illegal
+ if (BB.hasAddressTaken())
+ return false;
+
+ // don't hoist code that uses another basicblock address, as it's likely to
+ // lead to unexpected behavior, like cross-function jumps
+ SmallPtrSet<User const *, 16> Visited;
+ SmallVector<User const *, 16> ToVisit;
+
+ for (Instruction const &Inst : BB)
+ ToVisit.push_back(&Inst);
+
+ while (!ToVisit.empty()) {
+ User const *Curr = ToVisit.pop_back_val();
+ if (!Visited.insert(Curr).second)
+ continue;
+ if (isa<BlockAddress const>(Curr))
+ return false; // even a reference to self is likely to be not compatible
+
+ if (isa<Instruction>(Curr) && cast<Instruction>(Curr)->getParent() != &BB)
+ continue;
+
+ for (auto const &U : Curr->operands()) {
+ if (auto *UU = dyn_cast<User>(U))
+ ToVisit.push_back(UU);
+ }
+ }
+
+ // If explicitly requested, allow vastart and alloca. For invoke instructions
+ // verify that extraction is valid.
+ for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
+ if (isa<AllocaInst>(I)) {
+ if (!AllowAlloca)
+ return false;
+ continue;
+ }
+
+ if (const auto *II = dyn_cast<InvokeInst>(I)) {
+ // Unwind destination (either a landingpad, catchswitch, or cleanuppad)
+ // must be a part of the subgraph which is being extracted.
+ if (auto *UBB = II->getUnwindDest())
+ if (!Result.count(UBB))
+ return false;
+ continue;
+ }
+
+ // All catch handlers of a catchswitch instruction as well as the unwind
+ // destination must be in the subgraph.
+ if (const auto *CSI = dyn_cast<CatchSwitchInst>(I)) {
+ if (auto *UBB = CSI->getUnwindDest())
+ if (!Result.count(UBB))
+ return false;
+ for (const auto *HBB : CSI->handlers())
+ if (!Result.count(const_cast<BasicBlock*>(HBB)))
+ return false;
+ continue;
+ }
+
+ // Make sure that entire catch handler is within subgraph. It is sufficient
+ // to check that catch return's block is in the list.
+ if (const auto *CPI = dyn_cast<CatchPadInst>(I)) {
+ for (const auto *U : CPI->users())
+ if (const auto *CRI = dyn_cast<CatchReturnInst>(U))
+ if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
+ return false;
+ continue;
+ }
+
+ // And do similar checks for cleanup handler - the entire handler must be
+ // in subgraph which is going to be extracted. For cleanup return should
+ // additionally check that the unwind destination is also in the subgraph.
+ if (const auto *CPI = dyn_cast<CleanupPadInst>(I)) {
+ for (const auto *U : CPI->users())
+ if (const auto *CRI = dyn_cast<CleanupReturnInst>(U))
+ if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
+ return false;
+ continue;
+ }
+ if (const auto *CRI = dyn_cast<CleanupReturnInst>(I)) {
+ if (auto *UBB = CRI->getUnwindDest())
+ if (!Result.count(UBB))
+ return false;
+ continue;
+ }
+
+ if (const CallInst *CI = dyn_cast<CallInst>(I)) {
+ if (const Function *F = CI->getCalledFunction()) {
+ auto IID = F->getIntrinsicID();
+ if (IID == Intrinsic::vastart) {
+ if (AllowVarArgs)
+ continue;
+ else
+ return false;
+ }
+
+ // Currently, we miscompile outlined copies of eh_typid_for. There are
+ // proposals for fixing this in llvm.org/PR39545.
+ if (IID == Intrinsic::eh_typeid_for)
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+/// Build a set of blocks to extract if the input blocks are viable.
+static SetVector<BasicBlock *>
+buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
+ bool AllowVarArgs, bool AllowAlloca) {
+ assert(!BBs.empty() && "The set of blocks to extract must be non-empty");
+ SetVector<BasicBlock *> Result;
+
+ // Loop over the blocks, adding them to our set-vector, and aborting with an
+ // empty set if we encounter invalid blocks.
+ for (BasicBlock *BB : BBs) {
+ // If this block is dead, don't process it.
+ if (DT && !DT->isReachableFromEntry(BB))
+ continue;
+
+ if (!Result.insert(BB))
+ llvm_unreachable("Repeated basic blocks in extraction input");
+ }
+
+ LLVM_DEBUG(dbgs() << "Region front block: " << Result.front()->getName()
+ << '\n');
+
+ for (auto *BB : Result) {
+ if (!isBlockValidForExtraction(*BB, Result, AllowVarArgs, AllowAlloca))
+ return {};
+
+ // Make sure that the first block is not a landing pad.
+ if (BB == Result.front()) {
+ if (BB->isEHPad()) {
+ LLVM_DEBUG(dbgs() << "The first block cannot be an unwind block\n");
+ return {};
+ }
+ continue;
+ }
+
+ // All blocks other than the first must not have predecessors outside of
+ // the subgraph which is being extracted.
+ for (auto *PBB : predecessors(BB))
+ if (!Result.count(PBB)) {
+ LLVM_DEBUG(dbgs() << "No blocks in this region may have entries from "
+ "outside the region except for the first block!\n"
+ << "Problematic source BB: " << BB->getName() << "\n"
+ << "Problematic destination BB: " << PBB->getName()
+ << "\n");
+ return {};
+ }
+ }
+
+ return Result;
+}
+
+CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
+ bool AggregateArgs, BlockFrequencyInfo *BFI,
+ BranchProbabilityInfo *BPI, AssumptionCache *AC,
+ bool AllowVarArgs, bool AllowAlloca,
+ BasicBlock *AllocationBlock, std::string Suffix)
+ : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+ BPI(BPI), AC(AC), AllocationBlock(AllocationBlock),
+ AllowVarArgs(AllowVarArgs),
+ Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
+ Suffix(Suffix) {}
+
+CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
+ BlockFrequencyInfo *BFI,
+ BranchProbabilityInfo *BPI, AssumptionCache *AC,
+ std::string Suffix)
+ : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
+ BPI(BPI), AC(AC), AllocationBlock(nullptr), AllowVarArgs(false),
+ Blocks(buildExtractionBlockSet(L.getBlocks(), &DT,
+ /* AllowVarArgs */ false,
+ /* AllowAlloca */ false)),
+ Suffix(Suffix) {}
+
+/// definedInRegion - Return true if the specified value is defined in the
+/// extracted region.
+static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) {
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ if (Blocks.count(I->getParent()))
+ return true;
+ return false;
+}
+
+/// definedInCaller - Return true if the specified value is defined in the
+/// function being code extracted, but not in the region being extracted.
+/// These values must be passed in as live-ins to the function.
+static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) {
+ if (isa<Argument>(V)) return true;
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ if (!Blocks.count(I->getParent()))
+ return true;
+ return false;
+}
+
+static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) {
+ BasicBlock *CommonExitBlock = nullptr;
+ auto hasNonCommonExitSucc = [&](BasicBlock *Block) {
+ for (auto *Succ : successors(Block)) {
+ // Internal edges, ok.
+ if (Blocks.count(Succ))
+ continue;
+ if (!CommonExitBlock) {
+ CommonExitBlock = Succ;
+ continue;
+ }
+ if (CommonExitBlock != Succ)
+ return true;
+ }
+ return false;
+ };
+
+ if (any_of(Blocks, hasNonCommonExitSucc))
+ return nullptr;
+
+ return CommonExitBlock;
+}
+
+CodeExtractorAnalysisCache::CodeExtractorAnalysisCache(Function &F) {
+ for (BasicBlock &BB : F) {
+ for (Instruction &II : BB.instructionsWithoutDebug())
+ if (auto *AI = dyn_cast<AllocaInst>(&II))
+ Allocas.push_back(AI);
+
+ findSideEffectInfoForBlock(BB);
+ }
+}
+
+void CodeExtractorAnalysisCache::findSideEffectInfoForBlock(BasicBlock &BB) {
+ for (Instruction &II : BB.instructionsWithoutDebug()) {
+ unsigned Opcode = II.getOpcode();
+ Value *MemAddr = nullptr;
+ switch (Opcode) {
+ case Instruction::Store:
+ case Instruction::Load: {
+ if (Opcode == Instruction::Store) {
+ StoreInst *SI = cast<StoreInst>(&II);
+ MemAddr = SI->getPointerOperand();
+ } else {
+ LoadInst *LI = cast<LoadInst>(&II);
+ MemAddr = LI->getPointerOperand();
+ }
+ // Global variable can not be aliased with locals.
+ if (isa<Constant>(MemAddr))
+ break;
+ Value *Base = MemAddr->stripInBoundsConstantOffsets();
+ if (!isa<AllocaInst>(Base)) {
+ SideEffectingBlocks.insert(&BB);
+ return;
+ }
+ BaseMemAddrs[&BB].insert(Base);
+ break;
+ }
+ default: {
+ IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II);
+ if (IntrInst) {
+ if (IntrInst->isLifetimeStartOrEnd())
+ break;
+ SideEffectingBlocks.insert(&BB);
+ return;
+ }
+ // Treat all the other cases conservatively if it has side effects.
+ if (II.mayHaveSideEffects()) {
+ SideEffectingBlocks.insert(&BB);
+ return;
+ }
+ }
+ }
+ }
+}
+
+bool CodeExtractorAnalysisCache::doesBlockContainClobberOfAddr(
+ BasicBlock &BB, AllocaInst *Addr) const {
+ if (SideEffectingBlocks.count(&BB))
+ return true;
+ auto It = BaseMemAddrs.find(&BB);
+ if (It != BaseMemAddrs.end())
+ return It->second.count(Addr);
+ return false;
+}
+
+bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
+ const CodeExtractorAnalysisCache &CEAC, Instruction *Addr) const {
+ AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets());
+ Function *Func = (*Blocks.begin())->getParent();
+ for (BasicBlock &BB : *Func) {
+ if (Blocks.count(&BB))
+ continue;
+ if (CEAC.doesBlockContainClobberOfAddr(BB, AI))
+ return false;
+ }
+ return true;
+}
+
+BasicBlock *
+CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
+ BasicBlock *SinglePredFromOutlineRegion = nullptr;
+ assert(!Blocks.count(CommonExitBlock) &&
+ "Expect a block outside the region!");
+ for (auto *Pred : predecessors(CommonExitBlock)) {
+ if (!Blocks.count(Pred))
+ continue;
+ if (!SinglePredFromOutlineRegion) {
+ SinglePredFromOutlineRegion = Pred;
+ } else if (SinglePredFromOutlineRegion != Pred) {
+ SinglePredFromOutlineRegion = nullptr;
+ break;
+ }
+ }
+
+ if (SinglePredFromOutlineRegion)
+ return SinglePredFromOutlineRegion;
+
+#ifndef NDEBUG
+ auto getFirstPHI = [](BasicBlock *BB) {
+ BasicBlock::iterator I = BB->begin();
+ PHINode *FirstPhi = nullptr;
+ while (I != BB->end()) {
+ PHINode *Phi = dyn_cast<PHINode>(I);
+ if (!Phi)
+ break;
+ if (!FirstPhi) {
+ FirstPhi = Phi;
+ break;
+ }
+ }
+ return FirstPhi;
+ };
+ // If there are any phi nodes, the single pred either exists or has already
+ // be created before code extraction.
+ assert(!getFirstPHI(CommonExitBlock) && "Phi not expected");
+#endif
+
+ BasicBlock *NewExitBlock = CommonExitBlock->splitBasicBlock(
+ CommonExitBlock->getFirstNonPHI()->getIterator());
+
+ for (BasicBlock *Pred :
+ llvm::make_early_inc_range(predecessors(CommonExitBlock))) {
+ if (Blocks.count(Pred))
+ continue;
+ Pred->getTerminator()->replaceUsesOfWith(CommonExitBlock, NewExitBlock);
+ }
+ // Now add the old exit block to the outline region.
+ Blocks.insert(CommonExitBlock);
+ OldTargets.push_back(NewExitBlock);
+ return CommonExitBlock;
+}
+
+// Find the pair of life time markers for address 'Addr' that are either
+// defined inside the outline region or can legally be shrinkwrapped into the
+// outline region. If there are not other untracked uses of the address, return
+// the pair of markers if found; otherwise return a pair of nullptr.
+CodeExtractor::LifetimeMarkerInfo
+CodeExtractor::getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC,
+ Instruction *Addr,
+ BasicBlock *ExitBlock) const {
+ LifetimeMarkerInfo Info;
+
+ for (User *U : Addr->users()) {
+ IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U);
+ if (IntrInst) {
+ // We don't model addresses with multiple start/end markers, but the
+ // markers do not need to be in the region.
+ if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
+ if (Info.LifeStart)
+ return {};
+ Info.LifeStart = IntrInst;
+ continue;
+ }
+ if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
+ if (Info.LifeEnd)
+ return {};
+ Info.LifeEnd = IntrInst;
+ continue;
+ }
+ // At this point, permit debug uses outside of the region.
+ // This is fixed in a later call to fixupDebugInfoPostExtraction().
+ if (isa<DbgInfoIntrinsic>(IntrInst))
+ continue;
+ }
+ // Find untracked uses of the address, bail.
+ if (!definedInRegion(Blocks, U))
+ return {};
+ }
+
+ if (!Info.LifeStart || !Info.LifeEnd)
+ return {};
+
+ Info.SinkLifeStart = !definedInRegion(Blocks, Info.LifeStart);
+ Info.HoistLifeEnd = !definedInRegion(Blocks, Info.LifeEnd);
+ // Do legality check.
+ if ((Info.SinkLifeStart || Info.HoistLifeEnd) &&
+ !isLegalToShrinkwrapLifetimeMarkers(CEAC, Addr))
+ return {};
+
+ // Check to see if we have a place to do hoisting, if not, bail.
+ if (Info.HoistLifeEnd && !ExitBlock)
+ return {};
+
+ return Info;
+}
+
+void CodeExtractor::findAllocas(const CodeExtractorAnalysisCache &CEAC,
+ ValueSet &SinkCands, ValueSet &HoistCands,
+ BasicBlock *&ExitBlock) const {
+ Function *Func = (*Blocks.begin())->getParent();
+ ExitBlock = getCommonExitBlock(Blocks);
+
+ auto moveOrIgnoreLifetimeMarkers =
+ [&](const LifetimeMarkerInfo &LMI) -> bool {
+ if (!LMI.LifeStart)
+ return false;
+ if (LMI.SinkLifeStart) {
+ LLVM_DEBUG(dbgs() << "Sinking lifetime.start: " << *LMI.LifeStart
+ << "\n");
+ SinkCands.insert(LMI.LifeStart);
+ }
+ if (LMI.HoistLifeEnd) {
+ LLVM_DEBUG(dbgs() << "Hoisting lifetime.end: " << *LMI.LifeEnd << "\n");
+ HoistCands.insert(LMI.LifeEnd);
+ }
+ return true;
+ };
+
+ // Look up allocas in the original function in CodeExtractorAnalysisCache, as
+ // this is much faster than walking all the instructions.
+ for (AllocaInst *AI : CEAC.getAllocas()) {
+ BasicBlock *BB = AI->getParent();
+ if (Blocks.count(BB))
+ continue;
+
+ // As a prior call to extractCodeRegion() may have shrinkwrapped the alloca,
+ // check whether it is actually still in the original function.
+ Function *AIFunc = BB->getParent();
+ if (AIFunc != Func)
+ continue;
+
+ LifetimeMarkerInfo MarkerInfo = getLifetimeMarkers(CEAC, AI, ExitBlock);
+ bool Moved = moveOrIgnoreLifetimeMarkers(MarkerInfo);
+ if (Moved) {
+ LLVM_DEBUG(dbgs() << "Sinking alloca: " << *AI << "\n");
+ SinkCands.insert(AI);
+ continue;
+ }
+
+ // Find bitcasts in the outlined region that have lifetime marker users
+ // outside that region. Replace the lifetime marker use with an
+ // outside region bitcast to avoid unnecessary alloca/reload instructions
+ // and extra lifetime markers.
+ SmallVector<Instruction *, 2> LifetimeBitcastUsers;
+ for (User *U : AI->users()) {
+ if (!definedInRegion(Blocks, U))
+ continue;
+
+ if (U->stripInBoundsConstantOffsets() != AI)
+ continue;
+
+ Instruction *Bitcast = cast<Instruction>(U);
+ for (User *BU : Bitcast->users()) {
+ IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(BU);
+ if (!IntrInst)
+ continue;
+
+ if (!IntrInst->isLifetimeStartOrEnd())
+ continue;
+
+ if (definedInRegion(Blocks, IntrInst))
+ continue;
+
+ LLVM_DEBUG(dbgs() << "Replace use of extracted region bitcast"
+ << *Bitcast << " in out-of-region lifetime marker "
+ << *IntrInst << "\n");
+ LifetimeBitcastUsers.push_back(IntrInst);
+ }
+ }
+
+ for (Instruction *I : LifetimeBitcastUsers) {
+ Module *M = AIFunc->getParent();
+ LLVMContext &Ctx = M->getContext();
+ auto *Int8PtrTy = Type::getInt8PtrTy(Ctx);
+ CastInst *CastI =
+ CastInst::CreatePointerCast(AI, Int8PtrTy, "lt.cast", I);
+ I->replaceUsesOfWith(I->getOperand(1), CastI);
+ }
+
+ // Follow any bitcasts.
+ SmallVector<Instruction *, 2> Bitcasts;
+ SmallVector<LifetimeMarkerInfo, 2> BitcastLifetimeInfo;
+ for (User *U : AI->users()) {
+ if (U->stripInBoundsConstantOffsets() == AI) {
+ Instruction *Bitcast = cast<Instruction>(U);
+ LifetimeMarkerInfo LMI = getLifetimeMarkers(CEAC, Bitcast, ExitBlock);
+ if (LMI.LifeStart) {
+ Bitcasts.push_back(Bitcast);
+ BitcastLifetimeInfo.push_back(LMI);
+ continue;
+ }
+ }
+
+ // Found unknown use of AI.
+ if (!definedInRegion(Blocks, U)) {
+ Bitcasts.clear();
+ break;
+ }
+ }
+
+ // Either no bitcasts reference the alloca or there are unknown uses.
+ if (Bitcasts.empty())
+ continue;
+
+ LLVM_DEBUG(dbgs() << "Sinking alloca (via bitcast): " << *AI << "\n");
+ SinkCands.insert(AI);
+ for (unsigned I = 0, E = Bitcasts.size(); I != E; ++I) {
+ Instruction *BitcastAddr = Bitcasts[I];
+ const LifetimeMarkerInfo &LMI = BitcastLifetimeInfo[I];
+ assert(LMI.LifeStart &&
+ "Unsafe to sink bitcast without lifetime markers");
+ moveOrIgnoreLifetimeMarkers(LMI);
+ if (!definedInRegion(Blocks, BitcastAddr)) {
+ LLVM_DEBUG(dbgs() << "Sinking bitcast-of-alloca: " << *BitcastAddr
+ << "\n");
+ SinkCands.insert(BitcastAddr);
+ }
+ }
+ }
+}
+
+bool CodeExtractor::isEligible() const {
+ if (Blocks.empty())
+ return false;
+ BasicBlock *Header = *Blocks.begin();
+ Function *F = Header->getParent();
+
+ // For functions with varargs, check that varargs handling is only done in the
+ // outlined function, i.e vastart and vaend are only used in outlined blocks.
+ if (AllowVarArgs && F->getFunctionType()->isVarArg()) {
+ auto containsVarArgIntrinsic = [](const Instruction &I) {
+ if (const CallInst *CI = dyn_cast<CallInst>(&I))
+ if (const Function *Callee = CI->getCalledFunction())
+ return Callee->getIntrinsicID() == Intrinsic::vastart ||
+ Callee->getIntrinsicID() == Intrinsic::vaend;
+ return false;
+ };
+
+ for (auto &BB : *F) {
+ if (Blocks.count(&BB))
+ continue;
+ if (llvm::any_of(BB, containsVarArgIntrinsic))
+ return false;
+ }
+ }
+ return true;
+}
+
+void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
+ const ValueSet &SinkCands) const {
+ for (BasicBlock *BB : Blocks) {
+ // If a used value is defined outside the region, it's an input. If an
+ // instruction is used outside the region, it's an output.
+ for (Instruction &II : *BB) {
+ for (auto &OI : II.operands()) {
+ Value *V = OI;
+ if (!SinkCands.count(V) && definedInCaller(Blocks, V))
+ Inputs.insert(V);
+ }
+
+ for (User *U : II.users())
+ if (!definedInRegion(Blocks, U)) {
+ Outputs.insert(&II);
+ break;
+ }
+ }
+ }
+}
+
+/// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
+/// of the region, we need to split the entry block of the region so that the
+/// PHI node is easier to deal with.
+void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) {
+ unsigned NumPredsFromRegion = 0;
+ unsigned NumPredsOutsideRegion = 0;
+
+ if (Header != &Header->getParent()->getEntryBlock()) {
+ PHINode *PN = dyn_cast<PHINode>(Header->begin());
+ if (!PN) return; // No PHI nodes.
+
+ // If the header node contains any PHI nodes, check to see if there is more
+ // than one entry from outside the region. If so, we need to sever the
+ // header block into two.
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+ if (Blocks.count(PN->getIncomingBlock(i)))
+ ++NumPredsFromRegion;
+ else
+ ++NumPredsOutsideRegion;
+
+ // If there is one (or fewer) predecessor from outside the region, we don't
+ // need to do anything special.
+ if (NumPredsOutsideRegion <= 1) return;
+ }
+
+ // Otherwise, we need to split the header block into two pieces: one
+ // containing PHI nodes merging values from outside of the region, and a
+ // second that contains all of the code for the block and merges back any
+ // incoming values from inside of the region.
+ BasicBlock *NewBB = SplitBlock(Header, Header->getFirstNonPHI(), DT);
+
+ // We only want to code extract the second block now, and it becomes the new
+ // header of the region.
+ BasicBlock *OldPred = Header;
+ Blocks.remove(OldPred);
+ Blocks.insert(NewBB);
+ Header = NewBB;
+
+ // Okay, now we need to adjust the PHI nodes and any branches from within the
+ // region to go to the new header block instead of the old header block.
+ if (NumPredsFromRegion) {
+ PHINode *PN = cast<PHINode>(OldPred->begin());
+ // Loop over all of the predecessors of OldPred that are in the region,
+ // changing them to branch to NewBB instead.
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+ if (Blocks.count(PN->getIncomingBlock(i))) {
+ Instruction *TI = PN->getIncomingBlock(i)->getTerminator();
+ TI->replaceUsesOfWith(OldPred, NewBB);
+ }
+
+ // Okay, everything within the region is now branching to the right block, we
+ // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
+ BasicBlock::iterator AfterPHIs;
+ for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) {
+ PHINode *PN = cast<PHINode>(AfterPHIs);
+ // Create a new PHI node in the new region, which has an incoming value
+ // from OldPred of PN.
+ PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion,
+ PN->getName() + ".ce", &NewBB->front());
+ PN->replaceAllUsesWith(NewPN);
+ NewPN->addIncoming(PN, OldPred);
+
+ // Loop over all of the incoming value in PN, moving them to NewPN if they
+ // are from the extracted region.
+ for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) {
+ if (Blocks.count(PN->getIncomingBlock(i))) {
+ NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i));
+ PN->removeIncomingValue(i);
+ --i;
+ }
+ }
+ }
+ }
+}
+
+/// severSplitPHINodesOfExits - if PHI nodes in exit blocks have inputs from
+/// outlined region, we split these PHIs on two: one with inputs from region
+/// and other with remaining incoming blocks; then first PHIs are placed in
+/// outlined region.
+void CodeExtractor::severSplitPHINodesOfExits(
+ const SmallPtrSetImpl<BasicBlock *> &Exits) {
+ for (BasicBlock *ExitBB : Exits) {
+ BasicBlock *NewBB = nullptr;
+
+ for (PHINode &PN : ExitBB->phis()) {
+ // Find all incoming values from the outlining region.
+ SmallVector<unsigned, 2> IncomingVals;
+ for (unsigned i = 0; i < PN.getNumIncomingValues(); ++i)
+ if (Blocks.count(PN.getIncomingBlock(i)))
+ IncomingVals.push_back(i);
+
+ // Do not process PHI if there is one (or fewer) predecessor from region.
+ // If PHI has exactly one predecessor from region, only this one incoming
+ // will be replaced on codeRepl block, so it should be safe to skip PHI.
+ if (IncomingVals.size() <= 1)
+ continue;
+
+ // Create block for new PHIs and add it to the list of outlined if it
+ // wasn't done before.
+ if (!NewBB) {
+ NewBB = BasicBlock::Create(ExitBB->getContext(),
+ ExitBB->getName() + ".split",
+ ExitBB->getParent(), ExitBB);
+ SmallVector<BasicBlock *, 4> Preds(predecessors(ExitBB));
+ for (BasicBlock *PredBB : Preds)
+ if (Blocks.count(PredBB))
+ PredBB->getTerminator()->replaceUsesOfWith(ExitBB, NewBB);
+ BranchInst::Create(ExitBB, NewBB);
+ Blocks.insert(NewBB);
+ }
+
+ // Split this PHI.
+ PHINode *NewPN =
+ PHINode::Create(PN.getType(), IncomingVals.size(),
+ PN.getName() + ".ce", NewBB->getFirstNonPHI());
+ for (unsigned i : IncomingVals)
+ NewPN->addIncoming(PN.getIncomingValue(i), PN.getIncomingBlock(i));
+ for (unsigned i : reverse(IncomingVals))
+ PN.removeIncomingValue(i, false);
+ PN.addIncoming(NewPN, NewBB);
+ }
+ }
+}
+
+void CodeExtractor::splitReturnBlocks() {
+ for (BasicBlock *Block : Blocks)
+ if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) {
+ BasicBlock *New =
+ Block->splitBasicBlock(RI->getIterator(), Block->getName() + ".ret");
+ if (DT) {
+ // Old dominates New. New node dominates all other nodes dominated
+ // by Old.
+ DomTreeNode *OldNode = DT->getNode(Block);
+ SmallVector<DomTreeNode *, 8> Children(OldNode->begin(),
+ OldNode->end());
+
+ DomTreeNode *NewNode = DT->addNewBlock(New, Block);
+
+ for (DomTreeNode *I : Children)
+ DT->changeImmediateDominator(I, NewNode);
+ }
+ }
+}
+
+/// constructFunction - make a function based on inputs and outputs, as follows:
+/// f(in0, ..., inN, out0, ..., outN)
+Function *CodeExtractor::constructFunction(const ValueSet &inputs,
+ const ValueSet &outputs,
+ BasicBlock *header,
+ BasicBlock *newRootNode,
+ BasicBlock *newHeader,
+ Function *oldFunction,
+ Module *M) {
+ LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
+ LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
+
+ // This function returns unsigned, outputs will go back by reference.
+ switch (NumExitBlocks) {
+ case 0:
+ case 1: RetTy = Type::getVoidTy(header->getContext()); break;
+ case 2: RetTy = Type::getInt1Ty(header->getContext()); break;
+ default: RetTy = Type::getInt16Ty(header->getContext()); break;
+ }
+
+ std::vector<Type *> ParamTy;
+ std::vector<Type *> AggParamTy;
+ ValueSet StructValues;
+ const DataLayout &DL = M->getDataLayout();
+
+ // Add the types of the input values to the function's argument list
+ for (Value *value : inputs) {
+ LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
+ if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
+ AggParamTy.push_back(value->getType());
+ StructValues.insert(value);
+ } else
+ ParamTy.push_back(value->getType());
+ }
+
+ // Add the types of the output values to the function's argument list.
+ for (Value *output : outputs) {
+ LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
+ if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+ AggParamTy.push_back(output->getType());
+ StructValues.insert(output);
+ } else
+ ParamTy.push_back(
+ PointerType::get(output->getType(), DL.getAllocaAddrSpace()));
+ }
+
+ assert(
+ (ParamTy.size() + AggParamTy.size()) ==
+ (inputs.size() + outputs.size()) &&
+ "Number of scalar and aggregate params does not match inputs, outputs");
+ assert((StructValues.empty() || AggregateArgs) &&
+ "Expeced StructValues only with AggregateArgs set");
+
+ // Concatenate scalar and aggregate params in ParamTy.
+ size_t NumScalarParams = ParamTy.size();
+ StructType *StructTy = nullptr;
+ if (AggregateArgs && !AggParamTy.empty()) {
+ StructTy = StructType::get(M->getContext(), AggParamTy);
+ ParamTy.push_back(PointerType::get(StructTy, DL.getAllocaAddrSpace()));
+ }
+
+ LLVM_DEBUG({
+ dbgs() << "Function type: " << *RetTy << " f(";
+ for (Type *i : ParamTy)
+ dbgs() << *i << ", ";
+ dbgs() << ")\n";
+ });
+
+ FunctionType *funcType = FunctionType::get(
+ RetTy, ParamTy, AllowVarArgs && oldFunction->isVarArg());
+
+ std::string SuffixToUse =
+ Suffix.empty()
+ ? (header->getName().empty() ? "extracted" : header->getName().str())
+ : Suffix;
+ // Create the new function
+ Function *newFunction = Function::Create(
+ funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(),
+ oldFunction->getName() + "." + SuffixToUse, M);
+
+ // Inherit all of the target dependent attributes and white-listed
+ // target independent attributes.
+ // (e.g. If the extracted region contains a call to an x86.sse
+ // instruction we need to make sure that the extracted region has the
+ // "target-features" attribute allowing it to be lowered.
+ // FIXME: This should be changed to check to see if a specific
+ // attribute can not be inherited.
+ for (const auto &Attr : oldFunction->getAttributes().getFnAttrs()) {
+ if (Attr.isStringAttribute()) {
+ if (Attr.getKindAsString() == "thunk")
+ continue;
+ } else
+ switch (Attr.getKindAsEnum()) {
+ // Those attributes cannot be propagated safely. Explicitly list them
+ // here so we get a warning if new attributes are added.
+ case Attribute::AllocSize:
+ case Attribute::Builtin:
+ case Attribute::Convergent:
+ case Attribute::JumpTable:
+ case Attribute::Naked:
+ case Attribute::NoBuiltin:
+ case Attribute::NoMerge:
+ case Attribute::NoReturn:
+ case Attribute::NoSync:
+ case Attribute::ReturnsTwice:
+ case Attribute::Speculatable:
+ case Attribute::StackAlignment:
+ case Attribute::WillReturn:
+ case Attribute::AllocKind:
+ case Attribute::PresplitCoroutine:
+ case Attribute::Memory:
+ continue;
+ // Those attributes should be safe to propagate to the extracted function.
+ case Attribute::AlwaysInline:
+ case Attribute::Cold:
+ case Attribute::DisableSanitizerInstrumentation:
+ case Attribute::FnRetThunkExtern:
+ case Attribute::Hot:
+ case Attribute::NoRecurse:
+ case Attribute::InlineHint:
+ case Attribute::MinSize:
+ case Attribute::NoCallback:
+ case Attribute::NoDuplicate:
+ case Attribute::NoFree:
+ case Attribute::NoImplicitFloat:
+ case Attribute::NoInline:
+ case Attribute::NonLazyBind:
+ case Attribute::NoRedZone:
+ case Attribute::NoUnwind:
+ case Attribute::NoSanitizeBounds:
+ case Attribute::NoSanitizeCoverage:
+ case Attribute::NullPointerIsValid:
+ case Attribute::OptForFuzzing:
+ case Attribute::OptimizeNone:
+ case Attribute::OptimizeForSize:
+ case Attribute::SafeStack:
+ case Attribute::ShadowCallStack:
+ case Attribute::SanitizeAddress:
+ case Attribute::SanitizeMemory:
+ case Attribute::SanitizeThread:
+ case Attribute::SanitizeHWAddress:
+ case Attribute::SanitizeMemTag:
+ case Attribute::SpeculativeLoadHardening:
+ case Attribute::StackProtect:
+ case Attribute::StackProtectReq:
+ case Attribute::StackProtectStrong:
+ case Attribute::StrictFP:
+ case Attribute::UWTable:
+ case Attribute::VScaleRange:
+ case Attribute::NoCfCheck:
+ case Attribute::MustProgress:
+ case Attribute::NoProfile:
+ case Attribute::SkipProfile:
+ break;
+ // These attributes cannot be applied to functions.
+ case Attribute::Alignment:
+ case Attribute::AllocatedPointer:
+ case Attribute::AllocAlign:
+ case Attribute::ByVal:
+ case Attribute::Dereferenceable:
+ case Attribute::DereferenceableOrNull:
+ case Attribute::ElementType:
+ case Attribute::InAlloca:
+ case Attribute::InReg:
+ case Attribute::Nest:
+ case Attribute::NoAlias:
+ case Attribute::NoCapture:
+ case Attribute::NoUndef:
+ case Attribute::NonNull:
+ case Attribute::Preallocated:
+ case Attribute::ReadNone:
+ case Attribute::ReadOnly:
+ case Attribute::Returned:
+ case Attribute::SExt:
+ case Attribute::StructRet:
+ case Attribute::SwiftError:
+ case Attribute::SwiftSelf:
+ case Attribute::SwiftAsync:
+ case Attribute::ZExt:
+ case Attribute::ImmArg:
+ case Attribute::ByRef:
+ case Attribute::WriteOnly:
+ // These are not really attributes.
+ case Attribute::None:
+ case Attribute::EndAttrKinds:
+ case Attribute::EmptyKey:
+ case Attribute::TombstoneKey:
+ llvm_unreachable("Not a function attribute");
+ }
+
+ newFunction->addFnAttr(Attr);
+ }
+ newFunction->insert(newFunction->end(), newRootNode);
+
+ // Create scalar and aggregate iterators to name all of the arguments we
+ // inserted.
+ Function::arg_iterator ScalarAI = newFunction->arg_begin();
+ Function::arg_iterator AggAI = std::next(ScalarAI, NumScalarParams);
+
+ // Rewrite all users of the inputs in the extracted region to use the
+ // arguments (or appropriate addressing into struct) instead.
+ for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) {
+ Value *RewriteVal;
+ if (AggregateArgs && StructValues.contains(inputs[i])) {
+ Value *Idx[2];
+ Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
+ Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx);
+ Instruction *TI = newFunction->begin()->getTerminator();
+ GetElementPtrInst *GEP = GetElementPtrInst::Create(
+ StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI);
+ RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP,
+ "loadgep_" + inputs[i]->getName(), TI);
+ ++aggIdx;
+ } else
+ RewriteVal = &*ScalarAI++;
+
+ std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
+ for (User *use : Users)
+ if (Instruction *inst = dyn_cast<Instruction>(use))
+ if (Blocks.count(inst->getParent()))
+ inst->replaceUsesOfWith(inputs[i], RewriteVal);
+ }
+
+ // Set names for input and output arguments.
+ if (NumScalarParams) {
+ ScalarAI = newFunction->arg_begin();
+ for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++ScalarAI)
+ if (!StructValues.contains(inputs[i]))
+ ScalarAI->setName(inputs[i]->getName());
+ for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++ScalarAI)
+ if (!StructValues.contains(outputs[i]))
+ ScalarAI->setName(outputs[i]->getName() + ".out");
+ }
+
+ // Rewrite branches to basic blocks outside of the loop to new dummy blocks
+ // within the new function. This must be done before we lose track of which
+ // blocks were originally in the code region.
+ std::vector<User *> Users(header->user_begin(), header->user_end());
+ for (auto &U : Users)
+ // The BasicBlock which contains the branch is not in the region
+ // modify the branch target to a new block
+ if (Instruction *I = dyn_cast<Instruction>(U))
+ if (I->isTerminator() && I->getFunction() == oldFunction &&
+ !Blocks.count(I->getParent()))
+ I->replaceUsesOfWith(header, newHeader);
+
+ return newFunction;
+}
+
+/// Erase lifetime.start markers which reference inputs to the extraction
+/// region, and insert the referenced memory into \p LifetimesStart.
+///
+/// The extraction region is defined by a set of blocks (\p Blocks), and a set
+/// of allocas which will be moved from the caller function into the extracted
+/// function (\p SunkAllocas).
+static void eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks,
+ const SetVector<Value *> &SunkAllocas,
+ SetVector<Value *> &LifetimesStart) {
+ for (BasicBlock *BB : Blocks) {
+ for (Instruction &I : llvm::make_early_inc_range(*BB)) {
+ auto *II = dyn_cast<IntrinsicInst>(&I);
+ if (!II || !II->isLifetimeStartOrEnd())
+ continue;
+
+ // Get the memory operand of the lifetime marker. If the underlying
+ // object is a sunk alloca, or is otherwise defined in the extraction
+ // region, the lifetime marker must not be erased.
+ Value *Mem = II->getOperand(1)->stripInBoundsOffsets();
+ if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem))
+ continue;
+
+ if (II->getIntrinsicID() == Intrinsic::lifetime_start)
+ LifetimesStart.insert(Mem);
+ II->eraseFromParent();
+ }
+ }
+}
+
+/// Insert lifetime start/end markers surrounding the call to the new function
+/// for objects defined in the caller.
+static void insertLifetimeMarkersSurroundingCall(
+ Module *M, ArrayRef<Value *> LifetimesStart, ArrayRef<Value *> LifetimesEnd,
+ CallInst *TheCall) {
+ LLVMContext &Ctx = M->getContext();
+ auto Int8PtrTy = Type::getInt8PtrTy(Ctx);
+ auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1);
+ Instruction *Term = TheCall->getParent()->getTerminator();
+
+ // The memory argument to a lifetime marker must be a i8*. Cache any bitcasts
+ // needed to satisfy this requirement so they may be reused.
+ DenseMap<Value *, Value *> Bitcasts;
+
+ // Emit lifetime markers for the pointers given in \p Objects. Insert the
+ // markers before the call if \p InsertBefore, and after the call otherwise.
+ auto insertMarkers = [&](Function *MarkerFunc, ArrayRef<Value *> Objects,
+ bool InsertBefore) {
+ for (Value *Mem : Objects) {
+ assert((!isa<Instruction>(Mem) || cast<Instruction>(Mem)->getFunction() ==
+ TheCall->getFunction()) &&
+ "Input memory not defined in original function");
+ Value *&MemAsI8Ptr = Bitcasts[Mem];
+ if (!MemAsI8Ptr) {
+ if (Mem->getType() == Int8PtrTy)
+ MemAsI8Ptr = Mem;
+ else
+ MemAsI8Ptr =
+ CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall);
+ }
+
+ auto Marker = CallInst::Create(MarkerFunc, {NegativeOne, MemAsI8Ptr});
+ if (InsertBefore)
+ Marker->insertBefore(TheCall);
+ else
+ Marker->insertBefore(Term);
+ }
+ };
+
+ if (!LifetimesStart.empty()) {
+ auto StartFn = llvm::Intrinsic::getDeclaration(
+ M, llvm::Intrinsic::lifetime_start, Int8PtrTy);
+ insertMarkers(StartFn, LifetimesStart, /*InsertBefore=*/true);
+ }
+
+ if (!LifetimesEnd.empty()) {
+ auto EndFn = llvm::Intrinsic::getDeclaration(
+ M, llvm::Intrinsic::lifetime_end, Int8PtrTy);
+ insertMarkers(EndFn, LifetimesEnd, /*InsertBefore=*/false);
+ }
+}
+
+/// emitCallAndSwitchStatement - This method sets up the caller side by adding
+/// the call instruction, splitting any PHI nodes in the header block as
+/// necessary.
+CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
+ BasicBlock *codeReplacer,
+ ValueSet &inputs,
+ ValueSet &outputs) {
+ // Emit a call to the new function, passing in: *pointer to struct (if
+ // aggregating parameters), or plan inputs and allocated memory for outputs
+ std::vector<Value *> params, ReloadOutputs, Reloads;
+ ValueSet StructValues;
+
+ Module *M = newFunction->getParent();
+ LLVMContext &Context = M->getContext();
+ const DataLayout &DL = M->getDataLayout();
+ CallInst *call = nullptr;
+
+ // Add inputs as params, or to be filled into the struct
+ unsigned ScalarInputArgNo = 0;
+ SmallVector<unsigned, 1> SwiftErrorArgs;
+ for (Value *input : inputs) {
+ if (AggregateArgs && !ExcludeArgsFromAggregate.contains(input))
+ StructValues.insert(input);
+ else {
+ params.push_back(input);
+ if (input->isSwiftError())
+ SwiftErrorArgs.push_back(ScalarInputArgNo);
+ }
+ ++ScalarInputArgNo;
+ }
+
+ // Create allocas for the outputs
+ unsigned ScalarOutputArgNo = 0;
+ for (Value *output : outputs) {
+ if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
+ StructValues.insert(output);
+ } else {
+ AllocaInst *alloca =
+ new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
+ nullptr, output->getName() + ".loc",
+ &codeReplacer->getParent()->front().front());
+ ReloadOutputs.push_back(alloca);
+ params.push_back(alloca);
+ ++ScalarOutputArgNo;
+ }
+ }
+
+ StructType *StructArgTy = nullptr;
+ AllocaInst *Struct = nullptr;
+ unsigned NumAggregatedInputs = 0;
+ if (AggregateArgs && !StructValues.empty()) {
+ std::vector<Type *> ArgTypes;
+ for (Value *V : StructValues)
+ ArgTypes.push_back(V->getType());
+
+ // Allocate a struct at the beginning of this function
+ StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
+ Struct = new AllocaInst(
+ StructArgTy, DL.getAllocaAddrSpace(), nullptr, "structArg",
+ AllocationBlock ? &*AllocationBlock->getFirstInsertionPt()
+ : &codeReplacer->getParent()->front().front());
+ params.push_back(Struct);
+
+ // Store aggregated inputs in the struct.
+ for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
+ if (inputs.contains(StructValues[i])) {
+ Value *Idx[2];
+ Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+ Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
+ GetElementPtrInst *GEP = GetElementPtrInst::Create(
+ StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
+ GEP->insertInto(codeReplacer, codeReplacer->end());
+ new StoreInst(StructValues[i], GEP, codeReplacer);
+ NumAggregatedInputs++;
+ }
+ }
+ }
+
+ // Emit the call to the function
+ call = CallInst::Create(newFunction, params,
+ NumExitBlocks > 1 ? "targetBlock" : "");
+ // Add debug location to the new call, if the original function has debug
+ // info. In that case, the terminator of the entry block of the extracted
+ // function contains the first debug location of the extracted function,
+ // set in extractCodeRegion.
+ if (codeReplacer->getParent()->getSubprogram()) {
+ if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc())
+ call->setDebugLoc(DL);
+ }
+ call->insertInto(codeReplacer, codeReplacer->end());
+
+ // Set swifterror parameter attributes.
+ for (unsigned SwiftErrArgNo : SwiftErrorArgs) {
+ call->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
+ newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
+ }
+
+ // Reload the outputs passed in by reference, use the struct if output is in
+ // the aggregate or reload from the scalar argument.
+ for (unsigned i = 0, e = outputs.size(), scalarIdx = 0,
+ aggIdx = NumAggregatedInputs;
+ i != e; ++i) {
+ Value *Output = nullptr;
+ if (AggregateArgs && StructValues.contains(outputs[i])) {
+ Value *Idx[2];
+ Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+ Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
+ GetElementPtrInst *GEP = GetElementPtrInst::Create(
+ StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
+ GEP->insertInto(codeReplacer, codeReplacer->end());
+ Output = GEP;
+ ++aggIdx;
+ } else {
+ Output = ReloadOutputs[scalarIdx];
+ ++scalarIdx;
+ }
+ LoadInst *load = new LoadInst(outputs[i]->getType(), Output,
+ outputs[i]->getName() + ".reload",
+ codeReplacer);
+ Reloads.push_back(load);
+ std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end());
+ for (User *U : Users) {
+ Instruction *inst = cast<Instruction>(U);
+ if (!Blocks.count(inst->getParent()))
+ inst->replaceUsesOfWith(outputs[i], load);
+ }
+ }
+
+ // Now we can emit a switch statement using the call as a value.
+ SwitchInst *TheSwitch =
+ SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)),
+ codeReplacer, 0, codeReplacer);
+
+ // Since there may be multiple exits from the original region, make the new
+ // function return an unsigned, switch on that number. This loop iterates
+ // over all of the blocks in the extracted region, updating any terminator
+ // instructions in the to-be-extracted region that branch to blocks that are
+ // not in the region to be extracted.
+ std::map<BasicBlock *, BasicBlock *> ExitBlockMap;
+
+ // Iterate over the previously collected targets, and create new blocks inside
+ // the function to branch to.
+ unsigned switchVal = 0;
+ for (BasicBlock *OldTarget : OldTargets) {
+ if (Blocks.count(OldTarget))
+ continue;
+ BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
+ if (NewTarget)
+ continue;
+
+ // If we don't already have an exit stub for this non-extracted
+ // destination, create one now!
+ NewTarget = BasicBlock::Create(Context,
+ OldTarget->getName() + ".exitStub",
+ newFunction);
+ unsigned SuccNum = switchVal++;
+
+ Value *brVal = nullptr;
+ assert(NumExitBlocks < 0xffff && "too many exit blocks for switch");
+ switch (NumExitBlocks) {
+ case 0:
+ case 1: break; // No value needed.
+ case 2: // Conditional branch, return a bool
+ brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
+ break;
+ default:
+ brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
+ break;
+ }
+
+ ReturnInst::Create(Context, brVal, NewTarget);
+
+ // Update the switch instruction.
+ TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
+ SuccNum),
+ OldTarget);
+ }
+
+ for (BasicBlock *Block : Blocks) {
+ Instruction *TI = Block->getTerminator();
+ for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
+ if (Blocks.count(TI->getSuccessor(i)))
+ continue;
+ BasicBlock *OldTarget = TI->getSuccessor(i);
+ // add a new basic block which returns the appropriate value
+ BasicBlock *NewTarget = ExitBlockMap[OldTarget];
+ assert(NewTarget && "Unknown target block!");
+
+ // rewrite the original branch instruction with this new target
+ TI->setSuccessor(i, NewTarget);
+ }
+ }
+
+ // Store the arguments right after the definition of output value.
+ // This should be proceeded after creating exit stubs to be ensure that invoke
+ // result restore will be placed in the outlined function.
+ Function::arg_iterator ScalarOutputArgBegin = newFunction->arg_begin();
+ std::advance(ScalarOutputArgBegin, ScalarInputArgNo);
+ Function::arg_iterator AggOutputArgBegin = newFunction->arg_begin();
+ std::advance(AggOutputArgBegin, ScalarInputArgNo + ScalarOutputArgNo);
+
+ for (unsigned i = 0, e = outputs.size(), aggIdx = NumAggregatedInputs; i != e;
+ ++i) {
+ auto *OutI = dyn_cast<Instruction>(outputs[i]);
+ if (!OutI)
+ continue;
+
+ // Find proper insertion point.
+ BasicBlock::iterator InsertPt;
+ // In case OutI is an invoke, we insert the store at the beginning in the
+ // 'normal destination' BB. Otherwise we insert the store right after OutI.
+ if (auto *InvokeI = dyn_cast<InvokeInst>(OutI))
+ InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
+ else if (auto *Phi = dyn_cast<PHINode>(OutI))
+ InsertPt = Phi->getParent()->getFirstInsertionPt();
+ else
+ InsertPt = std::next(OutI->getIterator());
+
+ Instruction *InsertBefore = &*InsertPt;
+ assert((InsertBefore->getFunction() == newFunction ||
+ Blocks.count(InsertBefore->getParent())) &&
+ "InsertPt should be in new function");
+ if (AggregateArgs && StructValues.contains(outputs[i])) {
+ assert(AggOutputArgBegin != newFunction->arg_end() &&
+ "Number of aggregate output arguments should match "
+ "the number of defined values");
+ Value *Idx[2];
+ Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
+ Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
+ GetElementPtrInst *GEP = GetElementPtrInst::Create(
+ StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(),
+ InsertBefore);
+ new StoreInst(outputs[i], GEP, InsertBefore);
+ ++aggIdx;
+ // Since there should be only one struct argument aggregating
+ // all the output values, we shouldn't increment AggOutputArgBegin, which
+ // always points to the struct argument, in this case.
+ } else {
+ assert(ScalarOutputArgBegin != newFunction->arg_end() &&
+ "Number of scalar output arguments should match "
+ "the number of defined values");
+ new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertBefore);
+ ++ScalarOutputArgBegin;
+ }
+ }
+
+ // Now that we've done the deed, simplify the switch instruction.
+ Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
+ switch (NumExitBlocks) {
+ case 0:
+ // There are no successors (the block containing the switch itself), which
+ // means that previously this was the last part of the function, and hence
+ // this should be rewritten as a `ret'
+
+ // Check if the function should return a value
+ if (OldFnRetTy->isVoidTy()) {
+ ReturnInst::Create(Context, nullptr, TheSwitch); // Return void
+ } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
+ // return what we have
+ ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch);
+ } else {
+ // Otherwise we must have code extracted an unwind or something, just
+ // return whatever we want.
+ ReturnInst::Create(Context,
+ Constant::getNullValue(OldFnRetTy), TheSwitch);
+ }
+
+ TheSwitch->eraseFromParent();
+ break;
+ case 1:
+ // Only a single destination, change the switch into an unconditional
+ // branch.
+ BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch);
+ TheSwitch->eraseFromParent();
+ break;
+ case 2:
+ BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2),
+ call, TheSwitch);
+ TheSwitch->eraseFromParent();
+ break;
+ default:
+ // Otherwise, make the default destination of the switch instruction be one
+ // of the other successors.
+ TheSwitch->setCondition(call);
+ TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks));
+ // Remove redundant case
+ TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1));
+ break;
+ }
+
+ // Insert lifetime markers around the reloads of any output values. The
+ // allocas output values are stored in are only in-use in the codeRepl block.
+ insertLifetimeMarkersSurroundingCall(M, ReloadOutputs, ReloadOutputs, call);
+
+ return call;
+}
+
+void CodeExtractor::moveCodeToFunction(Function *newFunction) {
+ auto newFuncIt = newFunction->front().getIterator();
+ for (BasicBlock *Block : Blocks) {
+ // Delete the basic block from the old function, and the list of blocks
+ Block->removeFromParent();
+
+ // Insert this basic block into the new function
+ // Insert the original blocks after the entry block created
+ // for the new function. The entry block may be followed
+ // by a set of exit blocks at this point, but these exit
+ // blocks better be placed at the end of the new function.
+ newFuncIt = newFunction->insert(std::next(newFuncIt), Block);
+ }
+}
+
+void CodeExtractor::calculateNewCallTerminatorWeights(
+ BasicBlock *CodeReplacer,
+ DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
+ BranchProbabilityInfo *BPI) {
+ using Distribution = BlockFrequencyInfoImplBase::Distribution;
+ using BlockNode = BlockFrequencyInfoImplBase::BlockNode;
+
+ // Update the branch weights for the exit block.
+ Instruction *TI = CodeReplacer->getTerminator();
+ SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
+
+ // Block Frequency distribution with dummy node.
+ Distribution BranchDist;
+
+ SmallVector<BranchProbability, 4> EdgeProbabilities(
+ TI->getNumSuccessors(), BranchProbability::getUnknown());
+
+ // Add each of the frequencies of the successors.
+ for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
+ BlockNode ExitNode(i);
+ uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
+ if (ExitFreq != 0)
+ BranchDist.addExit(ExitNode, ExitFreq);
+ else
+ EdgeProbabilities[i] = BranchProbability::getZero();
+ }
+
+ // Check for no total weight.
+ if (BranchDist.Total == 0) {
+ BPI->setEdgeProbability(CodeReplacer, EdgeProbabilities);
+ return;
+ }
+
+ // Normalize the distribution so that they can fit in unsigned.
+ BranchDist.normalize();
+
+ // Create normalized branch weights and set the metadata.
+ for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
+ const auto &Weight = BranchDist.Weights[I];
+
+ // Get the weight and update the current BFI.
+ BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
+ BranchProbability BP(Weight.Amount, BranchDist.Total);
+ EdgeProbabilities[Weight.TargetNode.Index] = BP;
+ }
+ BPI->setEdgeProbability(CodeReplacer, EdgeProbabilities);
+ TI->setMetadata(
+ LLVMContext::MD_prof,
+ MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
+}
+
+/// Erase debug info intrinsics which refer to values in \p F but aren't in
+/// \p F.
+static void eraseDebugIntrinsicsWithNonLocalRefs(Function &F) {
+ for (Instruction &I : instructions(F)) {
+ SmallVector<DbgVariableIntrinsic *, 4> DbgUsers;
+ findDbgUsers(DbgUsers, &I);
+ for (DbgVariableIntrinsic *DVI : DbgUsers)
+ if (DVI->getFunction() != &F)
+ DVI->eraseFromParent();
+ }
+}
+
+/// Fix up the debug info in the old and new functions by pointing line
+/// locations and debug intrinsics to the new subprogram scope, and by deleting
+/// intrinsics which point to values outside of the new function.
+static void fixupDebugInfoPostExtraction(Function &OldFunc, Function &NewFunc,
+ CallInst &TheCall) {
+ DISubprogram *OldSP = OldFunc.getSubprogram();
+ LLVMContext &Ctx = OldFunc.getContext();
+
+ if (!OldSP) {
+ // Erase any debug info the new function contains.
+ stripDebugInfo(NewFunc);
+ // Make sure the old function doesn't contain any non-local metadata refs.
+ eraseDebugIntrinsicsWithNonLocalRefs(NewFunc);
+ return;
+ }
+
+ // Create a subprogram for the new function. Leave out a description of the
+ // function arguments, as the parameters don't correspond to anything at the
+ // source level.
+ assert(OldSP->getUnit() && "Missing compile unit for subprogram");
+ DIBuilder DIB(*OldFunc.getParent(), /*AllowUnresolved=*/false,
+ OldSP->getUnit());
+ auto SPType =
+ DIB.createSubroutineType(DIB.getOrCreateTypeArray(std::nullopt));
+ DISubprogram::DISPFlags SPFlags = DISubprogram::SPFlagDefinition |
+ DISubprogram::SPFlagOptimized |
+ DISubprogram::SPFlagLocalToUnit;
+ auto NewSP = DIB.createFunction(
+ OldSP->getUnit(), NewFunc.getName(), NewFunc.getName(), OldSP->getFile(),
+ /*LineNo=*/0, SPType, /*ScopeLine=*/0, DINode::FlagZero, SPFlags);
+ NewFunc.setSubprogram(NewSP);
+
+ // Debug intrinsics in the new function need to be updated in one of two
+ // ways:
+ // 1) They need to be deleted, because they describe a value in the old
+ // function.
+ // 2) They need to point to fresh metadata, e.g. because they currently
+ // point to a variable in the wrong scope.
+ SmallDenseMap<DINode *, DINode *> RemappedMetadata;
+ SmallVector<Instruction *, 4> DebugIntrinsicsToDelete;
+ DenseMap<const MDNode *, MDNode *> Cache;
+ for (Instruction &I : instructions(NewFunc)) {
+ auto *DII = dyn_cast<DbgInfoIntrinsic>(&I);
+ if (!DII)
+ continue;
+
+ // Point the intrinsic to a fresh label within the new function if the
+ // intrinsic was not inlined from some other function.
+ if (auto *DLI = dyn_cast<DbgLabelInst>(&I)) {
+ if (DLI->getDebugLoc().getInlinedAt())
+ continue;
+ DILabel *OldLabel = DLI->getLabel();
+ DINode *&NewLabel = RemappedMetadata[OldLabel];
+ if (!NewLabel) {
+ DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram(
+ *OldLabel->getScope(), *NewSP, Ctx, Cache);
+ NewLabel = DILabel::get(Ctx, NewScope, OldLabel->getName(),
+ OldLabel->getFile(), OldLabel->getLine());
+ }
+ DLI->setArgOperand(0, MetadataAsValue::get(Ctx, NewLabel));
+ continue;
+ }
+
+ auto IsInvalidLocation = [&NewFunc](Value *Location) {
+ // Location is invalid if it isn't a constant or an instruction, or is an
+ // instruction but isn't in the new function.
+ if (!Location ||
+ (!isa<Constant>(Location) && !isa<Instruction>(Location)))
+ return true;
+ Instruction *LocationInst = dyn_cast<Instruction>(Location);
+ return LocationInst && LocationInst->getFunction() != &NewFunc;
+ };
+
+ auto *DVI = cast<DbgVariableIntrinsic>(DII);
+ // If any of the used locations are invalid, delete the intrinsic.
+ if (any_of(DVI->location_ops(), IsInvalidLocation)) {
+ DebugIntrinsicsToDelete.push_back(DVI);
+ continue;
+ }
+ // If the variable was in the scope of the old function, i.e. it was not
+ // inlined, point the intrinsic to a fresh variable within the new function.
+ if (!DVI->getDebugLoc().getInlinedAt()) {
+ DILocalVariable *OldVar = DVI->getVariable();
+ DINode *&NewVar = RemappedMetadata[OldVar];
+ if (!NewVar) {
+ DILocalScope *NewScope = DILocalScope::cloneScopeForSubprogram(
+ *OldVar->getScope(), *NewSP, Ctx, Cache);
+ NewVar = DIB.createAutoVariable(
+ NewScope, OldVar->getName(), OldVar->getFile(), OldVar->getLine(),
+ OldVar->getType(), /*AlwaysPreserve=*/false, DINode::FlagZero,
+ OldVar->getAlignInBits());
+ }
+ DVI->setVariable(cast<DILocalVariable>(NewVar));
+ }
+ }
+
+ for (auto *DII : DebugIntrinsicsToDelete)
+ DII->eraseFromParent();
+ DIB.finalizeSubprogram(NewSP);
+
+ // Fix up the scope information attached to the line locations in the new
+ // function.
+ for (Instruction &I : instructions(NewFunc)) {
+ if (const DebugLoc &DL = I.getDebugLoc())
+ I.setDebugLoc(
+ DebugLoc::replaceInlinedAtSubprogram(DL, *NewSP, Ctx, Cache));
+
+ // Loop info metadata may contain line locations. Fix them up.
+ auto updateLoopInfoLoc = [&Ctx, &Cache, NewSP](Metadata *MD) -> Metadata * {
+ if (auto *Loc = dyn_cast_or_null<DILocation>(MD))
+ return DebugLoc::replaceInlinedAtSubprogram(Loc, *NewSP, Ctx, Cache);
+ return MD;
+ };
+ updateLoopMetadataDebugLocations(I, updateLoopInfoLoc);
+ }
+ if (!TheCall.getDebugLoc())
+ TheCall.setDebugLoc(DILocation::get(Ctx, 0, 0, OldSP));
+
+ eraseDebugIntrinsicsWithNonLocalRefs(NewFunc);
+}
+
+Function *
+CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC) {
+ ValueSet Inputs, Outputs;
+ return extractCodeRegion(CEAC, Inputs, Outputs);
+}
+
+Function *
+CodeExtractor::extractCodeRegion(const CodeExtractorAnalysisCache &CEAC,
+ ValueSet &inputs, ValueSet &outputs) {
+ if (!isEligible())
+ return nullptr;
+
+ // Assumption: this is a single-entry code region, and the header is the first
+ // block in the region.
+ BasicBlock *header = *Blocks.begin();
+ Function *oldFunction = header->getParent();
+
+ // Calculate the entry frequency of the new function before we change the root
+ // block.
+ BlockFrequency EntryFreq;
+ if (BFI) {
+ assert(BPI && "Both BPI and BFI are required to preserve profile info");
+ for (BasicBlock *Pred : predecessors(header)) {
+ if (Blocks.count(Pred))
+ continue;
+ EntryFreq +=
+ BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
+ }
+ }
+
+ // Remove CondGuardInsts that will be moved to the new function from the old
+ // function's assumption cache.
+ for (BasicBlock *Block : Blocks) {
+ for (Instruction &I : llvm::make_early_inc_range(*Block)) {
+ if (auto *CI = dyn_cast<CondGuardInst>(&I)) {
+ if (AC)
+ AC->unregisterAssumption(CI);
+ CI->eraseFromParent();
+ }
+ }
+ }
+
+ // If we have any return instructions in the region, split those blocks so
+ // that the return is not in the region.
+ splitReturnBlocks();
+
+ // Calculate the exit blocks for the extracted region and the total exit
+ // weights for each of those blocks.
+ DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
+ SmallPtrSet<BasicBlock *, 1> ExitBlocks;
+ for (BasicBlock *Block : Blocks) {
+ for (BasicBlock *Succ : successors(Block)) {
+ if (!Blocks.count(Succ)) {
+ // Update the branch weight for this successor.
+ if (BFI) {
+ BlockFrequency &BF = ExitWeights[Succ];
+ BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, Succ);
+ }
+ ExitBlocks.insert(Succ);
+ }
+ }
+ }
+ NumExitBlocks = ExitBlocks.size();
+
+ for (BasicBlock *Block : Blocks) {
+ Instruction *TI = Block->getTerminator();
+ for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
+ if (Blocks.count(TI->getSuccessor(i)))
+ continue;
+ BasicBlock *OldTarget = TI->getSuccessor(i);
+ OldTargets.push_back(OldTarget);
+ }
+ }
+
+ // If we have to split PHI nodes of the entry or exit blocks, do so now.
+ severSplitPHINodesOfEntry(header);
+ severSplitPHINodesOfExits(ExitBlocks);
+
+ // This takes place of the original loop
+ BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
+ "codeRepl", oldFunction,
+ header);
+
+ // The new function needs a root node because other nodes can branch to the
+ // head of the region, but the entry node of a function cannot have preds.
+ BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(),
+ "newFuncRoot");
+ auto *BranchI = BranchInst::Create(header);
+ // If the original function has debug info, we have to add a debug location
+ // to the new branch instruction from the artificial entry block.
+ // We use the debug location of the first instruction in the extracted
+ // blocks, as there is no other equivalent line in the source code.
+ if (oldFunction->getSubprogram()) {
+ any_of(Blocks, [&BranchI](const BasicBlock *BB) {
+ return any_of(*BB, [&BranchI](const Instruction &I) {
+ if (!I.getDebugLoc())
+ return false;
+ BranchI->setDebugLoc(I.getDebugLoc());
+ return true;
+ });
+ });
+ }
+ BranchI->insertInto(newFuncRoot, newFuncRoot->end());
+
+ ValueSet SinkingCands, HoistingCands;
+ BasicBlock *CommonExit = nullptr;
+ findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
+ assert(HoistingCands.empty() || CommonExit);
+
+ // Find inputs to, outputs from the code region.
+ findInputsOutputs(inputs, outputs, SinkingCands);
+
+ // Now sink all instructions which only have non-phi uses inside the region.
+ // Group the allocas at the start of the block, so that any bitcast uses of
+ // the allocas are well-defined.
+ AllocaInst *FirstSunkAlloca = nullptr;
+ for (auto *II : SinkingCands) {
+ if (auto *AI = dyn_cast<AllocaInst>(II)) {
+ AI->moveBefore(*newFuncRoot, newFuncRoot->getFirstInsertionPt());
+ if (!FirstSunkAlloca)
+ FirstSunkAlloca = AI;
+ }
+ }
+ assert((SinkingCands.empty() || FirstSunkAlloca) &&
+ "Did not expect a sink candidate without any allocas");
+ for (auto *II : SinkingCands) {
+ if (!isa<AllocaInst>(II)) {
+ cast<Instruction>(II)->moveAfter(FirstSunkAlloca);
+ }
+ }
+
+ if (!HoistingCands.empty()) {
+ auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit);
+ Instruction *TI = HoistToBlock->getTerminator();
+ for (auto *II : HoistingCands)
+ cast<Instruction>(II)->moveBefore(TI);
+ }
+
+ // Collect objects which are inputs to the extraction region and also
+ // referenced by lifetime start markers within it. The effects of these
+ // markers must be replicated in the calling function to prevent the stack
+ // coloring pass from merging slots which store input objects.
+ ValueSet LifetimesStart;
+ eraseLifetimeMarkersOnInputs(Blocks, SinkingCands, LifetimesStart);
+
+ // Construct new function based on inputs/outputs & add allocas for all defs.
+ Function *newFunction =
+ constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer,
+ oldFunction, oldFunction->getParent());
+
+ // Update the entry count of the function.
+ if (BFI) {
+ auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
+ if (Count)
+ newFunction->setEntryCount(
+ ProfileCount(*Count, Function::PCT_Real)); // FIXME
+ BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
+ }
+
+ CallInst *TheCall =
+ emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
+
+ moveCodeToFunction(newFunction);
+
+ // Replicate the effects of any lifetime start/end markers which referenced
+ // input objects in the extraction region by placing markers around the call.
+ insertLifetimeMarkersSurroundingCall(
+ oldFunction->getParent(), LifetimesStart.getArrayRef(), {}, TheCall);
+
+ // Propagate personality info to the new function if there is one.
+ if (oldFunction->hasPersonalityFn())
+ newFunction->setPersonalityFn(oldFunction->getPersonalityFn());
+
+ // Update the branch weights for the exit block.
+ if (BFI && NumExitBlocks > 1)
+ calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
+
+ // Loop over all of the PHI nodes in the header and exit blocks, and change
+ // any references to the old incoming edge to be the new incoming edge.
+ for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
+ PHINode *PN = cast<PHINode>(I);
+ for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+ if (!Blocks.count(PN->getIncomingBlock(i)))
+ PN->setIncomingBlock(i, newFuncRoot);
+ }
+
+ for (BasicBlock *ExitBB : ExitBlocks)
+ for (PHINode &PN : ExitBB->phis()) {
+ Value *IncomingCodeReplacerVal = nullptr;
+ for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) {
+ // Ignore incoming values from outside of the extracted region.
+ if (!Blocks.count(PN.getIncomingBlock(i)))
+ continue;
+
+ // Ensure that there is only one incoming value from codeReplacer.
+ if (!IncomingCodeReplacerVal) {
+ PN.setIncomingBlock(i, codeReplacer);
+ IncomingCodeReplacerVal = PN.getIncomingValue(i);
+ } else
+ assert(IncomingCodeReplacerVal == PN.getIncomingValue(i) &&
+ "PHI has two incompatbile incoming values from codeRepl");
+ }
+ }
+
+ fixupDebugInfoPostExtraction(*oldFunction, *newFunction, *TheCall);
+
+ // Mark the new function `noreturn` if applicable. Terminators which resume
+ // exception propagation are treated as returning instructions. This is to
+ // avoid inserting traps after calls to outlined functions which unwind.
+ bool doesNotReturn = none_of(*newFunction, [](const BasicBlock &BB) {
+ const Instruction *Term = BB.getTerminator();
+ return isa<ReturnInst>(Term) || isa<ResumeInst>(Term);
+ });
+ if (doesNotReturn)
+ newFunction->setDoesNotReturn();
+
+ LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) {
+ newFunction->dump();
+ report_fatal_error("verification of newFunction failed!");
+ });
+ LLVM_DEBUG(if (verifyFunction(*oldFunction))
+ report_fatal_error("verification of oldFunction failed!"));
+ LLVM_DEBUG(if (AC && verifyAssumptionCache(*oldFunction, *newFunction, AC))
+ report_fatal_error("Stale Asumption cache for old Function!"));
+ return newFunction;
+}
+
+bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
+ const Function &NewFunc,
+ AssumptionCache *AC) {
+ for (auto AssumeVH : AC->assumptions()) {
+ auto *I = dyn_cast_or_null<CondGuardInst>(AssumeVH);
+ if (!I)
+ continue;
+
+ // There shouldn't be any llvm.assume intrinsics in the new function.
+ if (I->getFunction() != &OldFunc)
+ return true;
+
+ // There shouldn't be any stale affected values in the assumption cache
+ // that were previously in the old function, but that have now been moved
+ // to the new function.
+ for (auto AffectedValVH : AC->assumptionsFor(I->getOperand(0))) {
+ auto *AffectedCI = dyn_cast_or_null<CondGuardInst>(AffectedValVH);
+ if (!AffectedCI)
+ continue;
+ if (AffectedCI->getFunction() != &OldFunc)
+ return true;
+ auto *AssumedInst = cast<Instruction>(AffectedCI->getOperand(0));
+ if (AssumedInst->getFunction() != &OldFunc)
+ return true;
+ }
+ }
+ return false;
+}
+
+void CodeExtractor::excludeArgFromAggregate(Value *Arg) {
+ ExcludeArgsFromAggregate.insert(Arg);
+}