//===- TailRecursionElimination.cpp - Eliminate Tail Calls ----------------===//
#include "llvm/Transforms/Scalar.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CaptureTracking.h"
#include "llvm/Analysis/CFG.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
// Everything below refers to LLVM entities without the llvm:: qualifier.
using namespace llvm;
// Debug category name for this file.
// NOTE(review): presumably picked up by the DEBUG() and STATISTIC() macros
// used in this file -- confirm against llvm/Support/Debug.h and Statistic.h.
#define DEBUG_TYPE "tailcallelim"
// Incremented once per recursive call that EliminateRecursiveTailCall removes.
STATISTIC(NumEliminated, "Number of tail calls removed");
// Incremented once per return that FoldReturnAndProcessPred folds into a
// predecessor block.
STATISTIC(NumRetDuped, "Number of return duplicated");
// Incremented once per accumulator PHI that EliminateRecursiveTailCall creates.
STATISTIC(NumAccumAdded, "Number of accumulators introduced");
namespace {
struct TailCallElim : public FunctionPass {
const TargetTransformInfo *TTI;
static char ID; TailCallElim() : FunctionPass(ID) {
initializeTailCallElimPass(*PassRegistry::getPassRegistry());
}
void getAnalysisUsage(AnalysisUsage &AU) const override;
bool runOnFunction(Function &F) override;
private:
bool runTRE(Function &F);
bool markTails(Function &F, bool &AllCallsAreTailCalls);
CallInst *FindTRECandidate(Instruction *I,
bool CannotTailCallElimCallsMarkedTail);
bool EliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret,
BasicBlock *&OldEntry,
bool &TailCallsAreMarkedTail,
SmallVectorImpl<PHINode *> &ArgumentPHIs,
bool CannotTailCallElimCallsMarkedTail);
bool FoldReturnAndProcessPred(BasicBlock *BB,
ReturnInst *Ret, BasicBlock *&OldEntry,
bool &TailCallsAreMarkedTail,
SmallVectorImpl<PHINode *> &ArgumentPHIs,
bool CannotTailCallElimCallsMarkedTail);
bool ProcessReturningBlock(ReturnInst *RI, BasicBlock *&OldEntry,
bool &TailCallsAreMarkedTail,
SmallVectorImpl<PHINode *> &ArgumentPHIs,
bool CannotTailCallElimCallsMarkedTail);
bool CanMoveAboveCall(Instruction *I, CallInst *CI);
Value *CanTransformAccumulatorRecursion(Instruction *I, CallInst *CI);
};
}
// Storage for the pass identifier declared in TailCallElim; the constructor
// passes it to the FunctionPass base class.
char TailCallElim::ID = 0;
// Pass registration under the name "tailcallelim" with the description
// "Tail Call Elimination", listing TargetTransformInfo as a dependency.
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags -- confirm against PassSupport.h.
INITIALIZE_PASS_BEGIN(TailCallElim, "tailcallelim",
"Tail Call Elimination", false, false)
INITIALIZE_AG_DEPENDENCY(TargetTransformInfo)
INITIALIZE_PASS_END(TailCallElim, "tailcallelim",
"Tail Call Elimination", false, false)
/// Public factory: returns a newly heap-allocated TailCallElim pass.
FunctionPass *llvm::createTailCallEliminationPass() {
  FunctionPass *Pass = new TailCallElim();
  return Pass;
}
/// Declares the analyses this pass requires. The only one is
/// TargetTransformInfo, which runTRE() fetches through getAnalysis.
void TailCallElim::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetTransformInfo>();
}
/// Scans every instruction of \p F for allocas.
/// \returns false as soon as a non-static alloca is found, true otherwise.
static bool CanTRE(Function &F) {
  for (auto &Block : F)
    for (auto &Inst : Block) {
      AllocaInst *Alloca = dyn_cast<AllocaInst>(&Inst);
      if (Alloca && !Alloca->isStaticAlloca())
        return false;
    }
  return true;
}
/// Pass entry point.
///
/// Skips functions rejected by skipOptnoneFunction(). Otherwise marks tail
/// calls and, only if markTails() reports that all calls are tail calls,
/// runs the recursion elimination as well.
/// \returns true if either step modified \p F.
bool TailCallElim::runOnFunction(Function &F) {
  if (skipOptnoneFunction(F))
    return false;

  bool EveryCallIsTail = false;
  const bool MarkedAny = markTails(F, EveryCallIsTail);
  if (!EveryCallIsTail)
    return MarkedAny;

  const bool Eliminated = runTRE(F);
  return MarkedAny || Eliminated;
}
namespace {
struct AllocaDerivedValueTracker {
void walk(Value *Root) {
SmallVector<Use *, 32> Worklist;
SmallPtrSet<Use *, 32> Visited;
auto AddUsesToWorklist = [&](Value *V) {
for (auto &U : V->uses()) {
if (!Visited.insert(&U))
continue;
Worklist.push_back(&U);
}
};
AddUsesToWorklist(Root);
while (!Worklist.empty()) {
Use *U = Worklist.pop_back_val();
Instruction *I = cast<Instruction>(U->getUser());
switch (I->getOpcode()) {
case Instruction::Call:
case Instruction::Invoke: {
CallSite CS(I);
bool IsNocapture = !CS.isCallee(U) &&
CS.doesNotCapture(CS.getArgumentNo(U));
callUsesLocalStack(CS, IsNocapture);
if (IsNocapture) {
continue;
}
break;
}
case Instruction::Load: {
continue;
}
case Instruction::Store: {
if (U->getOperandNo() == 0)
EscapePoints.insert(I);
continue; }
case Instruction::BitCast:
case Instruction::GetElementPtr:
case Instruction::PHI:
case Instruction::Select:
case Instruction::AddrSpaceCast:
break;
default:
EscapePoints.insert(I);
break;
}
AddUsesToWorklist(I);
}
}
void callUsesLocalStack(CallSite CS, bool IsNocapture) {
AllocaUsers.insert(CS.getInstruction());
if (IsNocapture)
return;
if (!CS.onlyReadsMemory())
EscapePoints.insert(CS.getInstruction());
}
SmallPtrSet<Instruction *, 32> AllocaUsers;
SmallPtrSet<Instruction *, 32> EscapePoints;
};
}
/// Marks as tail calls those calls in \p F that can be shown not to depend
/// on the caller's stack objects.
///
/// The uses of every alloca and every byval argument are traced with
/// AllocaDerivedValueTracker. The CFG is then walked from the entry block,
/// carrying an "escaped" state that becomes sticky once an escape point has
/// been passed. A call is marked when either
///  - it does not access memory and all of its arguments are constants or
///    non-byval arguments of \p F, or
///  - it is not a recorded user of a tracked value and its block is never
///    reached in the escaped state.
///
/// \param F the function to process.
/// \param AllCallsAreTailCalls set to true on entry to the analysis and reset
///        to false when a visited call could not be marked. It is left
///        untouched if \p F calls a function that returns twice.
/// \returns true if at least one call was newly marked.
bool TailCallElim::markTails(Function &F, bool &AllCallsAreTailCalls) {
  if (F.callsFunctionThatReturnsTwice())
    return false;
  AllCallsAreTailCalls = true;

  // Collect the users and escape points of all stack-based values.
  AllocaDerivedValueTracker Tracker;
  for (Argument &Arg : F.args()) {
    if (Arg.hasByValAttr())
      Tracker.walk(&Arg);
  }
  for (auto &BB : F) {
    for (auto &I : BB)
      if (AllocaInst *AI = dyn_cast<AllocaInst>(&I))
        Tracker.walk(AI);
  }

  bool Modified = false;

  // Per-block state, ordered so that a block can only be upgraded
  // (UNVISITED -> UNESCAPED -> ESCAPED), never downgraded.
  enum VisitType {
    UNVISITED,
    UNESCAPED,
    ESCAPED
  };
  DenseMap<BasicBlock *, VisitType> Visited;

  // Escaped blocks are processed first so that an unescaped visit is skipped
  // for any block that is known to be reachable in the escaped state.
  SmallVector<BasicBlock *, 32> WorklistUnescaped, WorklistEscaped;

  // Calls that look safe on the unescaped path; they are only marked at the
  // end, once it is known that their block was never reached as ESCAPED.
  SmallVector<CallInst *, 32> DeferredTails;

  BasicBlock *BB = &F.getEntryBlock();
  VisitType Escaped = UNESCAPED;
  do {
    for (auto &I : *BB) {
      if (Tracker.EscapePoints.count(&I))
        Escaped = ESCAPED;

      CallInst *CI = dyn_cast<CallInst>(&I);
      if (!CI || CI->isTailCall())
        continue;

      if (CI->doesNotAccessMemory()) {
        // A call that accesses no memory can be marked regardless of the
        // escape state, provided none of its arguments is computed inside
        // this function or is a byval argument.
        bool SafeToTail = true;
        for (auto &Arg : CI->arg_operands()) {
          // Inspect the argument value itself. Use::getUser() would yield the
          // call instruction, which is never a Constant or an Argument, so
          // every call with at least one argument would be rejected here.
          Value *ArgValue = Arg.get();
          if (isa<Constant>(ArgValue))
            continue;
          if (Argument *A = dyn_cast<Argument>(ArgValue))
            if (!A->hasByValAttr())
              continue;
          SafeToTail = false;
          break;
        }
        if (SafeToTail) {
          emitOptimizationRemark(
              F.getContext(), "tailcallelim", F, CI->getDebugLoc(),
              "marked this readnone call a tail call candidate");
          CI->setTailCall();
          Modified = true;
          continue;
        }
      }

      if (Escaped == UNESCAPED && !Tracker.AllocaUsers.count(CI)) {
        DeferredTails.push_back(CI);
      } else {
        AllCallsAreTailCalls = false;
      }
    }

    // Propagate the current state to the successors, queueing each one whose
    // state is raised by this visit.
    for (auto *SuccBB : make_range(succ_begin(BB), succ_end(BB))) {
      auto &State = Visited[SuccBB];
      if (State < Escaped) {
        State = Escaped;
        if (State == ESCAPED)
          WorklistEscaped.push_back(SuccBB);
        else
          WorklistUnescaped.push_back(SuccBB);
      }
    }

    if (!WorklistEscaped.empty()) {
      BB = WorklistEscaped.pop_back_val();
      Escaped = ESCAPED;
    } else {
      BB = nullptr;
      while (!WorklistUnescaped.empty()) {
        auto *NextBB = WorklistUnescaped.pop_back_val();
        // Skip blocks that were upgraded to ESCAPED after being queued.
        if (Visited[NextBB] == UNESCAPED) {
          BB = NextBB;
          Escaped = UNESCAPED;
          break;
        }
      }
    }
  } while (BB);

  for (CallInst *CI : DeferredTails) {
    if (Visited[CI->getParent()] != ESCAPED) {
      emitOptimizationRemark(F.getContext(), "tailcallelim", F,
                             CI->getDebugLoc(),
                             "marked this call a tail call candidate");
      CI->setTailCall();
      Modified = true;
    } else {
      AllCallsAreTailCalls = false;
    }
  }

  return Modified;
}
/// Rewrites self-recursive tail calls in \p F into branches back to the
/// function's entry.
///
/// Variadic functions are left alone. Every block ending in a return is
/// handed to ProcessReturningBlock(); if that does nothing and the block
/// consists only of the return (ignoring PHIs and debug intrinsics), the
/// return is folded into its predecessors instead. Blocks emptied by that
/// folding are erased afterwards, and the argument PHIs created along the way
/// are simplified where possible.
/// \returns true if \p F was changed.
bool TailCallElim::runTRE(Function &F) {
  if (F.getFunctionType()->isVarArg())
    return false;

  TTI = &getAnalysis<TargetTransformInfo>();

  BasicBlock *OldEntry = nullptr;
  bool TailCallsAreMarkedTail = false;
  SmallVector<PHINode *, 8> ArgumentPHIs;
  SmallVector<BasicBlock *, 8> DeadBlocks;
  bool Changed = false;

  // When CanTRE() rejects the function, calls marked tail must not be
  // eliminated.
  const bool SkipMarkedTailCalls = !CanTRE(F);

  for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) {
    ReturnInst *Ret = dyn_cast<ReturnInst>(BBI->getTerminator());
    if (!Ret)
      continue;

    bool BlockChanged =
        ProcessReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail,
                              ArgumentPHIs, SkipMarkedTailCalls);
    if (!BlockChanged && BBI->getFirstNonPHIOrDbg() == Ret) {
      BlockChanged =
          FoldReturnAndProcessPred(BBI, Ret, OldEntry, TailCallsAreMarkedTail,
                                   ArgumentPHIs, SkipMarkedTailCalls);
      // The block cannot be erased while it is being iterated over; remember
      // it and erase it after the loop.
      if (BlockChanged && BBI->empty())
        DeadBlocks.push_back(BBI);
    }
    Changed |= BlockChanged;
  }

  for (auto Dead : DeadBlocks)
    Dead->eraseFromParent();

  // Drop argument PHIs that simplify to an existing value.
  for (PHINode *PN : ArgumentPHIs) {
    Value *Simplified = SimplifyInstruction(PN);
    if (!Simplified)
      continue;
    PN->replaceAllUsesWith(Simplified);
    PN->eraseFromParent();
  }

  return Changed;
}
/// Decides whether instruction \p I may be moved above the call \p CI.
///
/// \returns false if \p I may have side effects, if it is a load that the
/// call could interfere with, or if it uses the result of \p CI; true
/// otherwise.
bool TailCallElim::CanMoveAboveCall(Instruction *I, CallInst *CI) {
  if (I->mayHaveSideEffects())
    return false;

  // A load may only cross a call with side effects if the call does not
  // write memory and the load is safe to execute unconditionally.
  LoadInst *Load = dyn_cast<LoadInst>(I);
  if (Load && CI->mayHaveSideEffects()) {
    if (CI->mayWriteToMemory())
      return false;
    if (!isSafeToLoadUnconditionally(Load->getPointerOperand(), Load,
                                     Load->getAlignment()))
      return false;
  }

  // An instruction consuming the call's result has to stay below it.
  for (unsigned Idx = 0, NumOps = I->getNumOperands(); Idx != NumOps; ++Idx) {
    if (I->getOperand(Idx) == CI)
      return false;
  }
  return true;
}
/// Returns true if \p V has the same value in every recursive invocation
/// reached through \p CI, as far as this function can tell.
///
/// That is the case when \p V is a Constant, when it is an argument that
/// \p CI passes through unchanged in the same position, or when it is the
/// condition of a switch that is the sole predecessor of the block of \p RI
/// and that block is not the switch's default destination.
static bool isDynamicConstant(Value *V, CallInst *CI, ReturnInst *RI) {
  if (isa<Constant>(V))
    return true;

  if (Argument *Arg = dyn_cast<Argument>(V)) {
    // Determine the position of Arg in the enclosing function's argument
    // list and check that the call forwards it in that same position.
    Function *F = CI->getParent()->getParent();
    unsigned ArgNo = 0;
    Function::arg_iterator AI = F->arg_begin();
    while (&*AI != Arg) {
      ++AI;
      ++ArgNo;
    }
    if (CI->getArgOperand(ArgNo) == Arg)
      return true;
  }

  BasicBlock *RetBB = RI->getParent();
  BasicBlock *UniquePred = RetBB->getUniquePredecessor();
  if (!UniquePred)
    return false;

  SwitchInst *SI = dyn_cast<SwitchInst>(UniquePred->getTerminator());
  if (!SI || SI->getCondition() != V)
    return false;

  return SI->getDefaultDest() != RetBB;
}
/// Looks at every return in the function containing \p CI, except
/// \p IgnoreRI, and returns the single value they all return.
///
/// \returns null if some return yields a value that isDynamicConstant()
/// rejects, if two returns yield different values, or if there is no return
/// to inspect.
static Value *getCommonReturnValue(ReturnInst *IgnoreRI, CallInst *CI) {
  Function *F = CI->getParent()->getParent();
  Value *Common = nullptr;

  for (auto &Block : *F) {
    ReturnInst *RI = dyn_cast<ReturnInst>(Block.getTerminator());
    if (!RI || RI == IgnoreRI)
      continue;

    Value *RetOp = RI->getOperand(0);
    if (!isDynamicConstant(RetOp, CI, RI))
      return nullptr;
    if (Common && Common != RetOp)
      return nullptr;
    Common = RetOp;
  }

  return Common;
}
/// Checks whether \p I combines the result of the recursive call \p CI in a
/// way that can be rewritten with an accumulator.
///
/// \p I must be associative and commutative, must have the call as exactly
/// one of its two operands, and its only user must be a return.
/// \returns the value shared by all other returns of the function (see
/// getCommonReturnValue), or null if the transformation does not apply.
Value *TailCallElim::CanTransformAccumulatorRecursion(Instruction *I,
                                                      CallInst *CI) {
  if (!I->isAssociative() || !I->isCommutative())
    return nullptr;
  assert(I->getNumOperands() == 2 &&
         "Associative/commutative operations should have 2 args!");

  // Exactly one of the two operands has to be the recursive call.
  const bool LHSIsCall = I->getOperand(0) == CI;
  const bool RHSIsCall = I->getOperand(1) == CI;
  if (LHSIsCall == RHSIsCall)
    return nullptr;

  // The result may only feed a single return instruction.
  ReturnInst *OnlyUser =
      I->hasOneUse() ? dyn_cast<ReturnInst>(I->user_back()) : nullptr;
  if (!OnlyUser)
    return nullptr;

  return getCommonReturnValue(OnlyUser, CI);
}
/// Advances \p I past any debug-info intrinsics and returns the first
/// instruction that is not one.
static Instruction *FirstNonDbg(BasicBlock::iterator I) {
  for (; isa<DbgInfoIntrinsic>(I); ++I) {
    // Nothing to do: just skip the debug intrinsic.
  }
  return &*I;
}
/// Searches backwards from \p TI, within its block, for a call to the
/// enclosing function that is a candidate for elimination.
///
/// \returns null when \p TI is the first instruction of its block, when no
/// self-call precedes it, when the call is marked tail while
/// \p CannotTailCallElimCallsMarkedTail is set, or when the function is just
/// a forwarding stub (see below). Otherwise returns the call found.
CallInst *
TailCallElim::FindTRECandidate(Instruction *TI,
                               bool CannotTailCallElimCallsMarkedTail) {
  BasicBlock *BB = TI->getParent();
  Function *F = BB->getParent();

  // Nothing precedes the terminator, so there cannot be a call.
  if (&BB->front() == TI)
    return nullptr;

  // Walk towards the start of the block until a call to F is found.
  CallInst *CI = nullptr;
  for (BasicBlock::iterator BBI = TI;; --BBI) {
    CI = dyn_cast<CallInst>(BBI);
    if (CI && CI->getCalledFunction() == F)
      break;
    if (BBI == BB->begin())
      return nullptr;
  }

  if (CannotTailCallElimCallsMarkedTail && CI->isTailCall())
    return nullptr;

  // Special case: the entry block contains only the call followed by TI
  // (debug intrinsics aside), and the target does not lower the callee to a
  // real call. If the call also forwards all arguments unchanged, leave it
  // alone.
  const bool IsCallThenTerminatorEntry =
      BB == &F->getEntryBlock() &&
      FirstNonDbg(BB->front()) == CI &&
      FirstNonDbg(std::next(BB->begin())) == TI;
  if (IsCallThenTerminatorEntry && CI->getCalledFunction() &&
      !TTI->isLoweredToCall(CI->getCalledFunction())) {
    CallSite CS(CI);
    CallSite::arg_iterator ArgIt = CS.arg_begin(), ArgEnd = CS.arg_end();
    Function::arg_iterator ParamIt = F->arg_begin(), ParamEnd = F->arg_end();

    // Stop at the first argument that is not the matching formal parameter.
    for (; ArgIt != ArgEnd && ParamIt != ParamEnd; ++ArgIt, ++ParamIt) {
      if (*ArgIt != &*ParamIt)
        break;
    }
    if (ArgIt == ArgEnd && ParamIt == ParamEnd)
      return nullptr;
  }

  return CI;
}
// Replaces the recursive call CI, which is followed by the return Ret in the
// same block, with a branch back to the original entry block.
//
// OldEntry, TailCallsAreMarkedTail and ArgumentPHIs carry state between
// successive eliminations in one function: on the first successful attempt a
// new entry block and one PHI per function argument are created and recorded
// there. CannotTailCallElimCallsMarkedTail is not referenced in this body.
//
// Returns true if the call was eliminated, false if the block was left with
// its call and return in place (note that the new entry block and argument
// PHIs may already have been created when false is returned from the
// tail-marker consistency check below).
bool TailCallElim::EliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret,
BasicBlock *&OldEntry,
bool &TailCallsAreMarkedTail,
SmallVectorImpl<PHINode *> &ArgumentPHIs,
bool CannotTailCallElimCallsMarkedTail) {
// Initial value of the accumulator, if an accumulator turns out to be needed.
Value *AccumulatorRecursionEliminationInitVal = nullptr;
// The instruction that combines the call's result, if there is one.
Instruction *AccumulatorRecursionInstr = nullptr;
// Every instruction between the call and the return must either be movable
// above the call or be an accumulator-style use of the call's result;
// anything else blocks the transformation.
BasicBlock::iterator BBI = CI;
for (++BBI; &*BBI != Ret; ++BBI) {
if (CanMoveAboveCall(BBI, CI)) continue;
if ((AccumulatorRecursionEliminationInitVal =
CanTransformAccumulatorRecursion(BBI, CI))) {
AccumulatorRecursionInstr = BBI;
} else {
return false; }
}
// The return yields something other than the call's result (and not undef),
// no accumulator instruction was found, and the function's returns do not
// all agree on one value: this is only acceptable if the value returned here
// is a dynamic constant and all *other* returns share a common value, which
// then serves as the accumulator's initial value.
if (Ret->getNumOperands() == 1 && Ret->getReturnValue() != CI &&
!isa<UndefValue>(Ret->getReturnValue()) &&
AccumulatorRecursionEliminationInitVal == nullptr &&
!getCommonReturnValue(nullptr, CI)) {
if (!isDynamicConstant(Ret->getReturnValue(), CI, Ret))
return false; AccumulatorRecursionEliminationInitVal = getCommonReturnValue(Ret, CI);
if (!AccumulatorRecursionEliminationInitVal)
return false;
}
BasicBlock *BB = Ret->getParent();
Function *F = BB->getParent();
emitOptimizationRemark(F->getContext(), "tailcallelim", *F, CI->getDebugLoc(),
"transforming tail recursion to loop");
// First elimination in this function: set up the loop header.
if (!OldEntry) {
// Insert a fresh entry block in front of the current one. It takes over the
// old entry's name and branches straight to it; the old entry becomes the
// loop header "tailrecurse".
OldEntry = &F->getEntryBlock();
BasicBlock *NewEntry = BasicBlock::Create(F->getContext(), "", F, OldEntry);
NewEntry->takeName(OldEntry);
OldEntry->setName("tailrecurse");
BranchInst::Create(OldEntry, NewEntry);
// Remember whether the first eliminated call carried the tail marker.
TailCallsAreMarkedTail = CI->isTailCall();
// If it did, move the allocas with a constant array size from the old entry
// into the new entry block, in front of its branch.
if (TailCallsAreMarkedTail)
for (BasicBlock::iterator OEBI = OldEntry->begin(), E = OldEntry->end(),
NEBI = NewEntry->begin(); OEBI != E; )
if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++))
if (isa<ConstantInt>(AI->getArraySize()))
AI->moveBefore(NEBI);
// Create one PHI per function argument at the top of the old entry. All uses
// of the argument are redirected to the PHI *before* the argument itself is
// added as the incoming value from the new entry block.
Instruction *InsertPos = OldEntry->begin();
for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
I != E; ++I) {
PHINode *PN = PHINode::Create(I->getType(), 2,
I->getName() + ".tr", InsertPos);
I->replaceAllUsesWith(PN); PN->addIncoming(I, NewEntry);
ArgumentPHIs.push_back(PN);
}
}
// Once a tail-marked call has been eliminated, calls without the marker are
// refused.
if (TailCallsAreMarkedTail && !CI->isTailCall())
return false;
// The call's actual arguments become the PHIs' incoming values along the new
// back edge from this block.
for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i)
ArgumentPHIs[i]->addIncoming(CI->getArgOperand(i), BB);
if (AccumulatorRecursionEliminationInitVal) {
Instruction *AccRecInstr = AccumulatorRecursionInstr;
// Create the accumulator PHI in the loop header, with room for one incoming
// value per existing predecessor plus the back edge added below.
pred_iterator PB = pred_begin(OldEntry), PE = pred_end(OldEntry);
PHINode *AccPN =
PHINode::Create(AccumulatorRecursionEliminationInitVal->getType(),
std::distance(PB, PE) + 1,
"accumulator.tr", OldEntry->begin());
// From the function entry the accumulator starts at its initial value; from
// every other existing predecessor it keeps its current value.
for (pred_iterator PI = PB; PI != PE; ++PI) {
BasicBlock *P = *PI;
if (P == &F->getEntryBlock())
AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, P);
else
AccPN->addIncoming(AccPN, P);
}
if (AccRecInstr) {
// The combining instruction feeds the accumulator along the new back edge,
// and its operand that was the recursive call is replaced by the accumulator
// (the boolean selects operand 1 when operand 0 is not the call).
AccPN->addIncoming(AccRecInstr, BB);
AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN);
} else {
// No combining instruction: the value returned here feeds the accumulator.
AccPN->addIncoming(Ret->getReturnValue(), BB);
}
// Every return in the function now returns the accumulator.
for (Function::iterator BBI = F->begin(), E = F->end(); BBI != E; ++BBI)
if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI->getTerminator()))
RI->setOperand(0, AccPN);
++NumAccumAdded;
}
// Replace the return with a branch to the loop header (inserted before Ret,
// carrying the call's debug location), then erase the return and the call.
BranchInst *NewBI = BranchInst::Create(OldEntry, Ret);
NewBI->setDebugLoc(CI->getDebugLoc());
BB->getInstList().erase(Ret); BB->getInstList().erase(CI); ++NumEliminated;
return true;
}
/// Handles a block \p BB that ends in the return \p Ret by folding that return
/// into each predecessor that ends in an unconditional branch and contains a
/// candidate recursive call, then eliminating the call there.
///
/// If \p BB loses all its predecessors (and its address is not taken), its
/// instruction list is cleared; the block itself is left for the caller to
/// erase.
/// \returns true if the return was folded into at least one predecessor.
bool TailCallElim::FoldReturnAndProcessPred(BasicBlock *BB,
                                       ReturnInst *Ret, BasicBlock *&OldEntry,
                                       bool &TailCallsAreMarkedTail,
                                       SmallVectorImpl<PHINode *> &ArgumentPHIs,
                                       bool CannotTailCallElimCallsMarkedTail) {
  // Collect the predecessors first: folding changes BB's predecessor list.
  SmallVector<BranchInst *, 8> UncondBranchPreds;
  for (pred_iterator PI = pred_begin(BB), PE = pred_end(BB); PI != PE; ++PI) {
    BranchInst *BI = dyn_cast<BranchInst>((*PI)->getTerminator());
    if (BI && BI->isUnconditional())
      UncondBranchPreds.push_back(BI);
  }

  bool Change = false;
  while (!UncondBranchPreds.empty()) {
    BranchInst *BI = UncondBranchPreds.pop_back_val();
    BasicBlock *Pred = BI->getParent();

    CallInst *CI = FindTRECandidate(BI, CannotTailCallElimCallsMarkedTail);
    if (!CI)
      continue;

    DEBUG(dbgs() << "FOLDING: " << *BB
          << "INTO UNCOND BRANCH PRED: " << *Pred);
    ReturnInst *RI = FoldReturnIntoUncondBranch(Ret, BB, Pred);

    // BB is still referenced by the caller, so it is only emptied here.
    if (!BB->hasAddressTaken() && pred_begin(BB) == pred_end(BB))
      BB->getInstList().clear();

    // The IR has changed through the fold regardless of whether the
    // elimination itself succeeds, hence the result is not inspected.
    EliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail,
                               ArgumentPHIs,
                               CannotTailCallElimCallsMarkedTail);
    ++NumRetDuped;
    Change = true;
  }

  return Change;
}
/// Tries to eliminate a recursive call in the block that ends with \p Ret.
/// \returns false if FindTRECandidate() finds no candidate call; otherwise
/// the result of EliminateRecursiveTailCall() for that call.
bool
TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry,
                                    bool &TailCallsAreMarkedTail,
                                    SmallVectorImpl<PHINode *> &ArgumentPHIs,
                                    bool CannotTailCallElimCallsMarkedTail) {
  if (CallInst *Candidate =
          FindTRECandidate(Ret, CannotTailCallElimCallsMarkedTail))
    return EliminateRecursiveTailCall(Candidate, Ret, OldEntry,
                                      TailCallsAreMarkedTail, ArgumentPHIs,
                                      CannotTailCallElimCallsMarkedTail);
  return false;
}