//===- ArgumentPromotion.cpp - Promote 'by reference' arguments to scalars -===//
#define DEBUG_TYPE "argpromotion"
#include "llvm/Transforms/IPO.h"
#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Module.h"
#include "llvm/CallGraphSCCPass.h"
#include "llvm/Instructions.h"
#include "llvm/LLVMContext.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Support/CallSite.h"
#include "llvm/Support/CFG.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringExtras.h"
#include <set>
using namespace llvm;
// Pass statistics.  Each counter is incremented in DoPromotion when an argument
// of the corresponding kind is rewritten (promoted scalar, promoted aggregate
// elements, expanded byval struct, or dropped unused pointer argument).
STATISTIC(NumArgumentsPromoted , "Number of pointer arguments promoted");
STATISTIC(NumAggregatesPromoted, "Number of aggregate arguments promoted");
STATISTIC(NumByValArgsPromoted , "Number of byval arguments promoted");
STATISTIC(NumArgumentsDead , "Number of dead pointer args eliminated");
namespace {
/// ArgPromotion - Call-graph SCC pass that rewrites internal functions so that
/// pointer arguments which are only loaded from are passed as the loaded
/// values instead, and small byval structs are passed element by element.
struct ArgPromotion : public CallGraphSCCPass {
  static char ID; // Pass identification.

  /// MaxElts bounds how many scalar arguments a single pointer argument may be
  /// expanded into; a value of 0 disables the bound.
  explicit ArgPromotion(unsigned MaxElts = 3)
      : CallGraphSCCPass(ID), maxElements(MaxElts) {
    initializeArgPromotionPass(*PassRegistry::getPassRegistry());
  }

  /// Alias analysis is required on top of what the SCC pass base requests.
  virtual void getAnalysisUsage(AnalysisUsage &AU) const {
    AU.addRequired<AliasAnalysis>();
    CallGraphSCCPass::getAnalysisUsage(AU);
  }

  virtual bool runOnSCC(CallGraphSCC &SCC);

  /// A path of constant GEP indices selecting one element of a pointee.
  typedef std::vector<uint64_t> IndicesVector;

private:
  CallGraphNode *PromoteArguments(CallGraphNode *CGN);
  bool isSafeToPromoteArgument(Argument *Arg, bool isByVal) const;
  CallGraphNode *DoPromotion(Function *F,
                             SmallPtrSet<Argument*, 8> &ArgsToPromote,
                             SmallPtrSet<Argument*, 8> &ByValArgsToTransform);

  unsigned maxElements; // Expansion limit per argument; 0 means unlimited.
};
} // end anonymous namespace
// Definition of the pass ID; the ArgPromotion constructor passes it to the
// CallGraphSCCPass base.
char ArgPromotion::ID = 0;
// Register the pass under the name "argpromotion", declaring its dependencies
// on AliasAnalysis and CallGraph.
INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion",
"Promote 'by reference' arguments to scalars", false, false)
INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
INITIALIZE_AG_DEPENDENCY(CallGraph)
INITIALIZE_PASS_END(ArgPromotion, "argpromotion",
"Promote 'by reference' arguments to scalars", false, false)
/// createArgumentPromotionPass - Allocate a new ArgPromotion pass whose
/// per-argument expansion limit is maxElements (0 = unlimited).  The returned
/// object is heap-allocated; ownership presumably passes to the caller / pass
/// manager.
Pass *llvm::createArgumentPromotionPass(unsigned maxElements) {
  ArgPromotion *ThePass = new ArgPromotion(maxElements);
  return ThePass;
}
/// runOnSCC - Repeatedly try to promote arguments of every function in the
/// SCC, replacing each rewritten function's node in the SCC, until one full
/// sweep changes nothing.  Returns true if any sweep changed something.
bool ArgPromotion::runOnSCC(CallGraphSCC &SCC) {
  bool AnyChange = false;
  for (;;) {
    bool SweepChanged = false;
    for (CallGraphSCC::iterator It = SCC.begin(), End = SCC.end(); It != End;
         ++It) {
      CallGraphNode *Replacement = PromoteArguments(*It);
      if (!Replacement)
        continue;
      SweepChanged = true;
      SCC.ReplaceNode(*It, Replacement);
    }
    if (!SweepChanged)
      break;
    AnyChange = true;
  }
  return AnyChange;
}
/// PromoteArguments - Try to promote the pointer arguments of the function
/// held by CGN.  Returns the call graph node of the rewritten function (from
/// DoPromotion), or null if nothing was changed.
///
/// Only functions with local linkage are considered, and every use of the
/// function must be as the callee of a direct call; otherwise null is
/// returned.  Byval struct arguments whose elements are all single-value types
/// (and, when maxElements > 0, that have at most maxElements elements) are
/// queued for element-wise passing.  Other pointer arguments are queued for
/// promotion if isSafeToPromoteArgument approves them.
CallGraphNode *ArgPromotion::PromoteArguments(CallGraphNode *CGN) {
Function *F = CGN->getFunction();
// Presumably only local functions qualify because every caller must be visible
// in order to be rewritten.
if (!F || !F->hasLocalLinkage()) return 0;
// Collect (argument, zero-based position) for every pointer-typed argument.
SmallVector<std::pair<Argument*, unsigned>, 16> PointerArgs;
unsigned ArgNo = 0;
for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
I != E; ++I, ++ArgNo)
if (I->getType()->isPointerTy())
PointerArgs.push_back(std::pair<Argument*, unsigned>(I, ArgNo));
if (PointerArgs.empty()) return 0;
// Every use of F must be a call site that uses F as the callee (any other use,
// e.g. taking its address, disqualifies F).  Also note whether F calls itself.
bool isSelfRecursive = false;
for (Value::use_iterator UI = F->use_begin(), E = F->use_end();
UI != E; ++UI) {
CallSite CS(*UI);
if (CS.getInstruction() == 0 || !CS.isCallee(UI)) return 0;
if (CS.getInstruction()->getParent()->getParent() == F)
isSelfRecursive = true;
}
SmallPtrSet<Argument*, 8> ArgsToPromote;
SmallPtrSet<Argument*, 8> ByValArgsToTransform;
for (unsigned i = 0; i != PointerArgs.size(); ++i) {
// Attribute indices are 1-based for parameters (index 0 is the return value).
bool isByVal = F->paramHasAttr(PointerArgs[i].second+1, Attribute::ByVal);
Argument *PtrArg = PointerArgs[i].first;
Type *AgTy = cast<PointerType>(PtrArg->getType())->getElementType();
if (isByVal) {
if (StructType *STy = dyn_cast<StructType>(AgTy)) {
// A byval struct that is too large is skipped entirely: it is neither
// expanded nor considered for ordinary promotion below.
if (maxElements > 0 && STy->getNumElements() > maxElements) {
DEBUG(dbgs() << "argpromotion disable promoting argument '"
<< PtrArg->getName() << "' because it would require adding more"
<< " than " << maxElements << " arguments to the function.\n");
continue;
}
// Element-wise passing is used only if every element is a single-value
// type.  (The inner loop variable shadows the outer 'i'.)
bool AllSimple = true;
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
if (!STy->getElementType(i)->isSingleValueType()) {
AllSimple = false;
break;
}
}
if (AllSimple) {
ByValArgsToTransform.insert(PtrArg);
continue;
}
}
// Other byval arguments fall through to the ordinary promotion check.
}
// In a self-recursive function, skip a struct argument that has an element
// whose type is the argument's own pointer type.  Presumably this prevents
// repeatedly peeling a recursive type on every runOnSCC sweep.
if (isSelfRecursive) {
if (StructType *STy = dyn_cast<StructType>(AgTy)) {
bool RecursiveType = false;
for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
if (STy->getElementType(i) == PtrArg->getType()) {
RecursiveType = true;
break;
}
}
if (RecursiveType)
continue;
}
}
if (isSafeToPromoteArgument(PtrArg, isByVal))
ArgsToPromote.insert(PtrArg);
}
// Nothing to rewrite.
if (ArgsToPromote.empty() && ByValArgsToTransform.empty())
return 0;
return DoPromotion(F, ArgsToPromote, ByValArgsToTransform);
}
/// AllCallersPassInValidPointerForArgument - Return true if, at every use of
/// Arg's parent function, the actual argument supplied in Arg's position is a
/// dereferenceable pointer.  Every use is expected to be a direct call site
/// (asserted).
static bool AllCallersPassInValidPointerForArgument(Argument *Arg) {
  Function *Callee = Arg->getParent();
  // Zero-based position of Arg in the callee's formal parameter list.
  unsigned ArgPos = std::distance(Callee->arg_begin(),
                                  Function::arg_iterator(Arg));

  Value::use_iterator It = Callee->use_begin();
  Value::use_iterator End = Callee->use_end();
  while (It != End) {
    CallSite Site(*It);
    assert(Site && "Should only have direct calls!");
    Value *Actual = Site.getArgument(ArgPos);
    if (!Actual->isDereferenceablePointer())
      return false;
    ++It;
  }
  return true;
}
/// IsPrefix - Return true if Prefix is a (not necessarily proper) prefix of
/// Longer, i.e. Longer starts with exactly the elements of Prefix.
static bool IsPrefix(const ArgPromotion::IndicesVector &Prefix,
                     const ArgPromotion::IndicesVector &Longer) {
  if (Longer.size() < Prefix.size())
    return false;
  ArgPromotion::IndicesVector::const_iterator P = Prefix.begin();
  ArgPromotion::IndicesVector::const_iterator PEnd = Prefix.end();
  ArgPromotion::IndicesVector::const_iterator L = Longer.begin();
  for (; P != PEnd; ++P, ++L) {
    if (*P != *L)
      return false;
  }
  return true;
}
/// PrefixIn - Return true if some element of Set is a prefix of Indices (or
/// equal to it).  Only the largest element that sorts <= Indices is examined;
/// this suffices because MarkIndicesSafe keeps the set prefix-free, so any
/// prefix of Indices present in Set is that element.
static bool PrefixIn(const ArgPromotion::IndicesVector &Indices,
                     std::set<ArgPromotion::IndicesVector> &Set) {
  if (Set.empty())
    return false;
  std::set<ArgPromotion::IndicesVector>::iterator Candidate =
      Set.upper_bound(Indices);
  // Step back to the last element <= Indices, when there is one.  If every
  // element sorts after Indices, Candidate stays at begin() and the prefix
  // test below necessarily fails.
  if (Candidate != Set.begin())
    --Candidate;
  return IsPrefix(*Candidate, Indices);
}
/// MarkIndicesSafe - Record that the index path ToMark (and therefore every
/// path it prefixes) is safe to load unconditionally.
///
/// Safe is kept prefix-free.  Nothing is inserted if a prefix of ToMark (or
/// ToMark itself) is already present.  Otherwise ToMark is inserted, and any
/// longer paths that have ToMark as a prefix are removed as redundant.
static void MarkIndicesSafe(const ArgPromotion::IndicesVector &ToMark,
                            std::set<ArgPromotion::IndicesVector> &Safe) {
  typedef std::set<ArgPromotion::IndicesVector>::iterator SetIter;

  // First element strictly greater than ToMark.  This is also exactly where
  // ToMark belongs, so it is the correct "insert before" hint.  (The previous
  // code incremented begin() when every element sorted after ToMark, which
  // produced a wrong hint.)
  SetIter Next = Safe.upper_bound(ToMark);

  if (Next != Safe.begin()) {
    // Prev is the largest element <= ToMark.  Because Safe is prefix-free, if
    // any element is a prefix of ToMark (or equal to it), it is this one.
    SetIter Prev = Next;
    --Prev;
    if (IsPrefix(*Prev, ToMark))
      return; // Already covered.
  }
  // If Next == begin(), every element sorts after ToMark, so none of them can
  // be a prefix of it.

  SetIter It = Safe.insert(Next, ToMark);
  ++It;

  // Paths that have ToMark as a proper prefix sort contiguously right after
  // it; they are now redundant.
  SetIter End = Safe.end();
  while (It != End && IsPrefix(ToMark, *It))
    Safe.erase(It++);
}
/// isSafeToPromoteArgument - Return true if Arg can be replaced by the values
/// loaded from it.  An argument with no uses is trivially promotable.
/// Otherwise all of the following must hold:
///   * every use of Arg is a simple (non-volatile, non-atomic) load, or a GEP
///     with all-constant indices whose users are all simple loads;
///   * every accessed index path has a prefix in SafeToUnconditionallyLoad;
///   * when maxElements > 0, at most maxElements distinct index paths are
///     accessed;
///   * alias analysis reports that nothing on any path from function entry to
///     each load may modify the loaded location.
///
/// Side effect: although the method is const, a dead GEP user of Arg is erased
/// (AliasAnalysis is notified first), and the analysis is restarted.
bool ArgPromotion::isSafeToPromoteArgument(Argument *Arg, bool isByVal) const {
  typedef std::set<IndicesVector> GEPIndicesSet;

  if (Arg->use_empty())
    return true;

  // Index paths that may be loaded unconditionally in the caller.  Kept
  // prefix-free by MarkIndicesSafe.
  GEPIndicesSet SafeToUnconditionallyLoad;

  // Distinct index paths that would become new arguments (bounded by
  // maxElements).
  GEPIndicesSet ToPromote;

  // If the pointer is byval or is dereferenceable at every call site, then
  // every path starting with index 0 is considered safe.
  if (isByVal || AllCallersPassInValidPointerForArgument(Arg))
    SafeToUnconditionallyLoad.insert(IndicesVector(1, 0));

  // Loads in the entry block (of Arg directly, or of a constant-index GEP of
  // Arg) mark their paths as safe.  NOTE(review): this presumes every
  // entry-block load executes whenever the function is entered; it does not
  // check for earlier instructions that might not return.  Confirm.
  BasicBlock *EntryBlock = Arg->getParent()->begin();
  IndicesVector Indices;
  for (BasicBlock::iterator I = EntryBlock->begin(), E = EntryBlock->end();
       I != E; ++I)
    if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
      Value *V = LI->getPointerOperand();
      if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(V)) {
        V = GEP->getPointerOperand();
        if (V == Arg) {
          Indices.reserve(GEP->getNumIndices());
          for (User::op_iterator II = GEP->idx_begin(), IE = GEP->idx_end();
               II != IE; ++II)
            if (ConstantInt *CI = dyn_cast<ConstantInt>(*II))
              Indices.push_back(CI->getSExtValue());
            else
              // A non-constant index into Arg: it cannot be promoted at all.
              return false;
          MarkIndicesSafe(Indices, SafeToUnconditionallyLoad);
          Indices.clear();
        }
      } else if (V == Arg) {
        // A direct load is treated like a GEP with the single index 0.
        MarkIndicesSafe(IndicesVector(1, 0), SafeToUnconditionallyLoad);
      }
    }

  // Check every use of Arg.  Operands receives the index path of the access;
  // a direct load counts as the path [0].
  SmallVector<LoadInst*, 16> Loads;
  IndicesVector Operands;
  for (Value::use_iterator UI = Arg->use_begin(), E = Arg->use_end();
       UI != E; ++UI) {
    User *U = *UI;
    Operands.clear();
    if (LoadInst *LI = dyn_cast<LoadInst>(U)) {
      // Volatile/atomic loads are not promoted.
      if (!LI->isSimple()) return false;
      Loads.push_back(LI);
      Operands.push_back(0);
    } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
      if (GEP->use_empty()) {
        // Erase the dead GEP (telling AA first) and restart the analysis from
        // scratch rather than continue iterating Arg's use list.
        getAnalysis<AliasAnalysis>().deleteValue(GEP);
        GEP->eraseFromParent();
        return isSafeToPromoteArgument(Arg, isByVal);
      }

      for (User::op_iterator i = GEP->idx_begin(), e = GEP->idx_end();
           i != e; ++i)
        if (ConstantInt *C = dyn_cast<ConstantInt>(*i))
          Operands.push_back(C->getSExtValue());
        else
          return false; // Non-constant GEP index.

      for (Value::use_iterator UI = GEP->use_begin(), E = GEP->use_end();
           UI != E; ++UI)
        if (LoadInst *LI = dyn_cast<LoadInst>(*UI)) {
          if (!LI->isSimple()) return false;
          Loads.push_back(LI);
        } else {
          return false; // The GEP is used by something other than a load.
        }
    } else {
      return false; // Neither a load nor a GEP.
    }

    // Loading this path in the caller must be known to be safe.
    if (!PrefixIn(Operands, SafeToUnconditionallyLoad))
      return false;

    // Each distinct path becomes one new argument; respect maxElements.
    if (ToPromote.find(Operands) == ToPromote.end()) {
      if (maxElements > 0 && ToPromote.size() == maxElements) {
        DEBUG(dbgs() << "argpromotion not promoting argument '"
              << Arg->getName() << "' because it would require adding more "
              << "than " << maxElements << " arguments to the function.\n");
        return false;
      }
      ToPromote.insert(Operands);
    }
  }

  if (Loads.empty()) return true;

  AliasAnalysis &AA = getAnalysis<AliasAnalysis>();

  for (unsigned i = 0, e = Loads.size(); i != e; ++i) {
    LoadInst *Load = Loads[i];
    BasicBlock *BB = Load->getParent();

    // Nothing from the start of the load's block up to the load may modify
    // the loaded location.
    AliasAnalysis::Location Loc = AA.getLocation(Load);
    if (AA.canInstructionRangeModify(BB->front(), *Load, Loc))
      return false;

    // Blocks already proven transparent to *this* load's location.  The set
    // must be per-load.  Loads can access different locations (index path,
    // size, TBAA tag), and a block that cannot modify one location may still
    // modify another.  A shared visited set would make the inverse DFS skip
    // such blocks for later loads and wrongly report the promotion as safe.
    SmallPtrSet<BasicBlock*, 16> TranspBlocks;

    // Inverse DFS from each predecessor: every block from which the load's
    // block is reachable must not modify the location.
    for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) {
      BasicBlock *P = *PI;
      for (idf_ext_iterator<BasicBlock*, SmallPtrSet<BasicBlock*, 16> >
             I = idf_ext_begin(P, TranspBlocks),
             E = idf_ext_end(P, TranspBlocks); I != E; ++I)
        if (AA.canBasicBlockModify(**I, Loc))
          return false;
    }
  }
  return true;
}
/// DoPromotion - Rewrite F into a new function NF whose selected pointer
/// arguments are replaced by the loaded values, and whose byval struct
/// arguments are replaced by one parameter per struct element.  All call sites
/// of F are rewritten to call NF, inserting the GEPs/loads in the caller.
/// F's body is spliced into NF, and the old call graph node is deleted (or F
/// is made external if its node is still referenced).  Returns NF's call graph
/// node.
CallGraphNode *ArgPromotion::DoPromotion(Function *F,
                             SmallPtrSet<Argument*, 8> &ArgsToPromote,
                             SmallPtrSet<Argument*, 8> &ByValArgsToTransform) {
  FunctionType *FTy = F->getFunctionType();
  std::vector<Type*> Params;

  typedef std::set<IndicesVector> ScalarizeTable;

  // For each promoted (non-byval) argument: the constant index paths loaded
  // from it.  Each path becomes one new parameter.  A direct load (or a GEP
  // with the single index 0) is recorded as the empty path.
  std::map<Argument*, ScalarizeTable> ScalarizedElements;

  // A representative original load for each (argument, index path).  Its
  // alignment and TBAA tag are copied to the loads inserted at call sites, and
  // it seeds AA for the new values.  The key includes the argument: two
  // promoted arguments can share a path (e.g. both are loaded directly), and
  // their loads may differ in alignment or TBAA tag.
  typedef std::pair<Argument*, IndicesVector> OrigLoadKey;
  std::map<OrigLoadKey, LoadInst*> OriginalLoads;

  // Attribute list for NF: return attributes, then the attributes of each
  // surviving parameter (re-indexed to its new position), then function
  // attributes (index ~0).
  SmallVector<AttributeWithIndex, 8> AttributesVec;
  const AttrListPtr &PAL = F->getAttributes();

  if (Attributes attrs = PAL.getRetAttributes())
    AttributesVec.push_back(AttributeWithIndex::get(0, attrs));

  // First pass over F's formals: build NF's parameter list.
  unsigned ArgIndex = 1;
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(); I != E;
       ++I, ++ArgIndex) {
    if (ByValArgsToTransform.count(I)) {
      // Byval struct: one parameter per element.
      Type *AgTy = cast<PointerType>(I->getType())->getElementType();
      StructType *STy = cast<StructType>(AgTy);
      for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i)
        Params.push_back(STy->getElementType(i));
      ++NumByValArgsPromoted;
    } else if (!ArgsToPromote.count(I)) {
      // Unchanged argument: keep its type and parameter attributes.
      Params.push_back(I->getType());
      if (Attributes attrs = PAL.getParamAttributes(ArgIndex))
        AttributesVec.push_back(AttributeWithIndex::get(Params.size(), attrs));
    } else if (I->use_empty()) {
      // Promotable but unused: drop it.
      ++NumArgumentsDead;
    } else {
      // Promoted argument.  Each use is a load, or a constant-index GEP used
      // only by loads (checked by isSafeToPromoteArgument).
      ScalarizeTable &ArgIndices = ScalarizedElements[I];
      for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); UI != E;
           ++UI) {
        Instruction *ArgUser = cast<Instruction>(*UI);
        assert(isa<LoadInst>(ArgUser) || isa<GetElementPtrInst>(ArgUser));
        IndicesVector Indices;
        Indices.reserve(ArgUser->getNumOperands() - 1);
        // Operand 0 is the pointer; any remaining operands are GEP indices.
        for (User::op_iterator II = ArgUser->op_begin() + 1,
               IE = ArgUser->op_end(); II != IE; ++II)
          Indices.push_back(cast<ConstantInt>(*II)->getSExtValue());
        // A GEP with the single index 0 addresses the same thing as a direct
        // load; canonicalize both to the empty path.
        if (Indices.size() == 1 && Indices.front() == 0)
          Indices.clear();
        ArgIndices.insert(Indices);
        LoadInst *OrigLoad;
        if (LoadInst *L = dyn_cast<LoadInst>(ArgUser))
          OrigLoad = L;
        else
          // Use the GEP's last user (always a load) as the representative.
          OrigLoad = cast<LoadInst>(ArgUser->use_back());
        OriginalLoads[OrigLoadKey(I, Indices)] = OrigLoad;
      }

      // One new parameter per distinct path, in set order.
      for (ScalarizeTable::iterator SI = ArgIndices.begin(),
             E = ArgIndices.end(); SI != E; ++SI) {
        Params.push_back(GetElementPtrInst::getIndexedType(I->getType(), *SI));
        assert(Params.back());
      }

      if (ArgIndices.size() == 1 && ArgIndices.begin()->empty())
        ++NumArgumentsPromoted;
      else
        ++NumAggregatesPromoted;
    }
  }

  if (Attributes attrs = PAL.getFnAttributes())
    AttributesVec.push_back(AttributeWithIndex::get(~0, attrs));

  Type *RetTy = FTy->getReturnType();

  // A varargs function left with no fixed parameters gets a dummy i32
  // parameter (callers pass 0).  This is presumably a workaround for
  // consumers that cannot handle that signature shape.
  bool ExtraArgHack = false;
  if (Params.empty() && FTy->isVarArg()) {
    ExtraArgHack = true;
    Params.push_back(Type::getInt32Ty(F->getContext()));
  }

  FunctionType *NFTy = FunctionType::get(RetTy, Params, FTy->isVarArg());
  Function *NF = Function::Create(NFTy, F->getLinkage(), F->getName());
  NF->copyAttributesFrom(F);

  DEBUG(dbgs() << "ARG PROMOTION: Promoting to:" << *NF << "\n"
        << "From: " << *F);

  // Replace the copied attributes with the list recomputed for the new
  // parameters.
  NF->setAttributes(AttrListPtr::get(AttributesVec.begin(),
                                     AttributesVec.end()));
  AttributesVec.clear();

  F->getParent()->getFunctionList().insert(F, NF);
  NF->takeName(F);

  AliasAnalysis &AA = getAnalysis<AliasAnalysis>();
  CallGraph &CG = getAnalysis<CallGraph>();
  CallGraphNode *NF_CGN = CG.getOrInsertFunction(NF);

  // Rewrite every call site of F into a call of NF.
  SmallVector<Value*, 16> Args;
  while (!F->use_empty()) {
    CallSite CS(F->use_back());
    assert(CS.getCalledFunction() == F);
    Instruction *Call = CS.getInstruction();
    const AttrListPtr &CallPAL = CS.getAttributes();

    if (Attributes attrs = CallPAL.getRetAttributes())
      AttributesVec.push_back(AttributeWithIndex::get(0, attrs));

    CallSite::arg_iterator AI = CS.arg_begin();
    ArgIndex = 1;
    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
         I != E; ++I, ++AI, ++ArgIndex)
      if (!ArgsToPromote.count(I) && !ByValArgsToTransform.count(I)) {
        // Unchanged argument.
        Args.push_back(*AI);
        if (Attributes Attrs = CallPAL.getParamAttributes(ArgIndex))
          AttributesVec.push_back(AttributeWithIndex::get(Args.size(), Attrs));
      } else if (ByValArgsToTransform.count(I)) {
        // Load each struct element in the caller and pass it separately.
        Type *AgTy = cast<PointerType>(I->getType())->getElementType();
        StructType *STy = cast<StructType>(AgTy);
        Value *Idxs[2] = {
          ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), 0 };
        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
          Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i);
          Value *Idx = GetElementPtrInst::Create(*AI, Idxs,
                                                 (*AI)->getName()+"."+utostr(i),
                                                 Call);
          Args.push_back(new LoadInst(Idx, Idx->getName()+".val", Call));
        }
      } else if (!I->use_empty()) {
        // Promoted argument: one GEP (if needed) and load per path, in the
        // same order as the new parameters.
        ScalarizeTable &ArgIndices = ScalarizedElements[I];
        std::vector<Value*> Ops;
        for (ScalarizeTable::iterator SI = ArgIndices.begin(),
               E = ArgIndices.end(); SI != E; ++SI) {
          Value *V = *AI;
          LoadInst *OrigLoad = OriginalLoads[OrigLoadKey(I, *SI)];
          if (!SI->empty()) {
            Ops.reserve(SI->size());
            Type *ElTy = V->getType();
            for (IndicesVector::const_iterator II = SI->begin(),
                   IE = SI->end(); II != IE; ++II) {
              // Struct fields are indexed with i32, everything else with i64.
              Type *IdxTy = (ElTy->isStructTy() ?
                             Type::getInt32Ty(F->getContext()) :
                             Type::getInt64Ty(F->getContext()));
              Ops.push_back(ConstantInt::get(IdxTy, *II));
              ElTy = cast<CompositeType>(ElTy)->getTypeAtIndex(*II);
            }
            V = GetElementPtrInst::Create(V, Ops, V->getName()+".idx", Call);
            Ops.clear();
            AA.copyValue(OrigLoad->getOperand(0), V);
          }
          // Mirror this argument's original load: same alignment, same TBAA.
          LoadInst *newLoad = new LoadInst(V, V->getName()+".val", Call);
          newLoad->setAlignment(OrigLoad->getAlignment());
          newLoad->setMetadata(LLVMContext::MD_tbaa,
                               OrigLoad->getMetadata(LLVMContext::MD_tbaa));
          Args.push_back(newLoad);
          AA.copyValue(OrigLoad, Args.back());
        }
      }

    if (ExtraArgHack)
      Args.push_back(Constant::getNullValue(Type::getInt32Ty(F->getContext())));

    // Forward any variadic actuals unchanged, with their attributes.
    for (; AI != CS.arg_end(); ++AI, ++ArgIndex) {
      Args.push_back(*AI);
      if (Attributes Attrs = CallPAL.getParamAttributes(ArgIndex))
        AttributesVec.push_back(AttributeWithIndex::get(Args.size(), Attrs));
    }

    if (Attributes attrs = CallPAL.getFnAttributes())
      AttributesVec.push_back(AttributeWithIndex::get(~0, attrs));

    Instruction *New;
    if (InvokeInst *II = dyn_cast<InvokeInst>(Call)) {
      New = InvokeInst::Create(NF, II->getNormalDest(), II->getUnwindDest(),
                               Args, "", Call);
      cast<InvokeInst>(New)->setCallingConv(CS.getCallingConv());
      cast<InvokeInst>(New)->setAttributes(AttrListPtr::get(AttributesVec.begin(),
                                                            AttributesVec.end()));
    } else {
      New = CallInst::Create(NF, Args, "", Call);
      cast<CallInst>(New)->setCallingConv(CS.getCallingConv());
      cast<CallInst>(New)->setAttributes(AttrListPtr::get(AttributesVec.begin(),
                                                          AttributesVec.end()));
      if (cast<CallInst>(Call)->isTailCall())
        cast<CallInst>(New)->setTailCall();
    }
    Args.clear();
    AttributesVec.clear();

    AA.replaceWithNewValue(Call, New);

    // Redirect the call edge in the node of the function containing the call.
    CallGraphNode *CallerNode = CG[Call->getParent()->getParent()];
    CallerNode->replaceCallEdge(Call, New, NF_CGN);

    if (!Call->use_empty()) {
      Call->replaceAllUsesWith(New);
      New->takeName(Call);
    }

    // Removing the old call drops one use of F.
    Call->eraseFromParent();
  }

  // Move F's body into NF, leaving F empty.
  NF->getBasicBlockList().splice(NF->begin(), F->getBasicBlockList());

  // Second pass over F's formals: redirect uses to NF's new arguments.
  for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(),
       I2 = NF->arg_begin(); I != E; ++I) {
    if (!ArgsToPromote.count(I) && !ByValArgsToTransform.count(I)) {
      // Unchanged argument: move its name and users to the new argument.
      I->replaceAllUsesWith(I2);
      I2->takeName(I);
      AA.replaceWithNewValue(I, I2);
      ++I2;
      continue;
    }

    if (ByValArgsToTransform.count(I)) {
      // Rebuild the struct in a fresh entry-block alloca from the incoming
      // element parameters, and use the alloca in place of the old argument.
      Instruction *InsertPt = NF->begin()->begin();
      Type *AgTy = cast<PointerType>(I->getType())->getElementType();
      Value *TheAlloca = new AllocaInst(AgTy, 0, "", InsertPt);
      StructType *STy = cast<StructType>(AgTy);
      Value *Idxs[2] = {
        ConstantInt::get(Type::getInt32Ty(F->getContext()), 0), 0 };
      for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
        Idxs[1] = ConstantInt::get(Type::getInt32Ty(F->getContext()), i);
        Value *Idx =
          GetElementPtrInst::Create(TheAlloca, Idxs,
                                    TheAlloca->getName()+"."+Twine(i),
                                    InsertPt);
        I2->setName(I->getName()+"."+Twine(i));
        new StoreInst(I2++, Idx, InsertPt);
      }
      I->replaceAllUsesWith(TheAlloca);
      TheAlloca->takeName(I);
      AA.replaceWithNewValue(I, TheAlloca);

      // The byval memory is now an alloca local to NF.  A call marked 'tail'
      // must not access its caller's allocas, so any call that takes the
      // alloca directly must drop the marker.  NOTE(review): only direct
      // users are examined here; pointers derived from the alloca (e.g. GEPs)
      // passed to tail calls are not handled.
      for (Value::use_iterator UI = TheAlloca->use_begin(),
             UE = TheAlloca->use_end(); UI != UE; ++UI)
        if (CallInst *AllocaUser = dyn_cast<CallInst>(*UI))
          AllocaUser->setTailCall(false);
      continue;
    }

    if (I->use_empty()) {
      // Dead promotable argument: it has no counterpart in NF.
      AA.deleteValue(I);
      continue;
    }

    // Promoted argument: every load of it (direct or through a GEP) is
    // replaced by the matching new parameter.
    ScalarizeTable &ArgIndices = ScalarizedElements[I];
    while (!I->use_empty()) {
      if (LoadInst *LI = dyn_cast<LoadInst>(I->use_back())) {
        // The empty path sorts first, so the direct-load value is I2.
        assert(ArgIndices.begin()->empty() &&
               "Load element should sort to front!");
        I2->setName(I->getName()+".val");
        LI->replaceAllUsesWith(I2);
        AA.replaceWithNewValue(LI, I2);
        LI->eraseFromParent();
        DEBUG(dbgs() << "*** Promoted load of argument '" << I->getName()
              << "' in function '" << F->getName() << "'\n");
      } else {
        GetElementPtrInst *GEP = cast<GetElementPtrInst>(I->use_back());
        IndicesVector Operands;
        Operands.reserve(GEP->getNumIndices());
        for (User::op_iterator II = GEP->idx_begin(), IE = GEP->idx_end();
             II != IE; ++II)
          Operands.push_back(cast<ConstantInt>(*II)->getSExtValue());
        // Same canonicalization as in the first pass.
        if (Operands.size() == 1 && Operands.front() == 0)
          Operands.clear();

        // Find the new parameter for this path: its position in ArgIndices
        // is its offset from I2.  Check for end() before dereferencing.
        Function::arg_iterator TheArg = I2;
        ScalarizeTable::iterator It = ArgIndices.begin();
        while (It != ArgIndices.end() && *It != Operands) {
          ++It;
          ++TheArg;
        }
        assert(It != ArgIndices.end() && "GEP not handled??");

        std::string NewName = I->getName();
        for (unsigned i = 0, e = Operands.size(); i != e; ++i) {
          NewName += "." + utostr(Operands[i]);
        }
        NewName += ".val";
        TheArg->setName(NewName);

        DEBUG(dbgs() << "*** Promoted agg argument '" << TheArg->getName()
              << "' of function '" << NF->getName() << "'\n");

        // All users of the GEP are loads; replace them with the parameter.
        while (!GEP->use_empty()) {
          LoadInst *L = cast<LoadInst>(GEP->use_back());
          L->replaceAllUsesWith(TheArg);
          AA.replaceWithNewValue(L, TheArg);
          L->eraseFromParent();
        }
        AA.deleteValue(GEP);
        GEP->eraseFromParent();
      }
    }

    // Skip past all the parameters added for this promoted pointer.
    for (unsigned i = 0, e = ArgIndices.size(); i != e; ++i)
      ++I2;
  }

  // Tell AA about the dummy argument added by the varargs workaround.
  if (ExtraArgHack)
    AA.copyValue(Constant::getNullValue(Type::getInt32Ty(F->getContext())),
                 NF->arg_begin());

  AA.replaceWithNewValue(F, NF);

  NF_CGN->stealCalledFunctionsFrom(CG[F]);

  // Delete F's node and F if nothing references the node any more; otherwise
  // keep the (now empty) F around with external linkage.
  CallGraphNode *CGN = CG[F];
  if (CGN->getNumReferences() == 0)
    delete CG.removeFunctionFromModule(CGN);
  else
    F->setLinkage(Function::ExternalLinkage);

  return NF_CGN;
}