// ARM64PromoteConstant.cpp
#define DEBUG_TYPE "arm64-promote-const"
#include "ARM64.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
using namespace llvm;
// Hidden boolean command-line switch. When it is set, shouldConvert() accepts
// every constant that is neither undef nor a zero value, regardless of its
// type.
static cl::opt<bool> Stress("arm64-stress-promote-const", cl::Hidden,
cl::desc("Promote all vector constants"));
// Bumped once for each global variable created to hold a promoted constant.
STATISTIC(NumPromoted, "Number of promoted constants");
// Bumped once for each operand rewritten to use the load of a promoted global.
STATISTIC(NumPromotedUses, "Number of promoted constants uses");
namespace {
class ARM64PromoteConstant : public ModulePass {
public:
static char ID;
ARM64PromoteConstant() : ModulePass(ID) {}
virtual const char *getPassName() const {
return "ARM64 Promote Constant";
}
bool runOnModule(Module &M) {
DEBUG(dbgs() << getPassName() << '\n');
bool Changed = false;
for (Module::iterator IFn = M.begin(), IEndFn = M.end(); IFn != IEndFn;
++IFn) {
Changed |= runOnFunction(*IFn);
}
return Changed;
}
private:
bool runOnFunction(Function &F);
virtual void getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesCFG();
AU.addRequired<DominatorTreeWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
}
typedef SmallVector<Value::use_iterator, 4> Users;
typedef DenseMap<Instruction *, Users> InsertionPoints;
typedef DenseMap<Function *, InsertionPoints> InsertionPointsPerFunc;
Instruction * findInsertionPoint(Value::use_iterator &Use);
bool isDominated(Instruction *NewPt, Value::use_iterator &UseIt,
InsertionPoints &InsertPts);
bool tryAndMerge(Instruction *NewPt, Value::use_iterator &UseIt,
InsertionPoints &InsertPts);
void computeInsertionPoints(Constant *Val,
InsertionPointsPerFunc &InsPtsPerFunc);
bool insertDefinitions(Constant *Cst,
InsertionPointsPerFunc &InsPtsPerFunc);
bool computeAndInsertDefinitions(Constant *Val);
bool promoteConstant(Constant *Cst);
static void appendAndTransferDominatedUses(Instruction *NewPt,
Value::use_iterator &UseIt,
InsertionPoints::iterator &IPI,
InsertionPoints &InsertPts) {
IPI->second.push_back(UseIt);
Instruction *OldInstr = IPI->first;
InsertPts.insert(InsertionPoints::value_type(NewPt, IPI->second));
IPI = InsertPts.find(OldInstr);
InsertPts.erase(IPI);
}
};
}
// Storage for the static member that the constructor passes to ModulePass.
char ARM64PromoteConstant::ID = 0;
// Forward declaration of the initializer for this pass; presumably the
// definition is produced by the INITIALIZE_PASS_* macros below.
namespace llvm {
void initializeARM64PromoteConstantPass(PassRegistry&);
}
// Register the pass under the name "arm64-promote-const" and declare its
// dependency on DominatorTreeWrapperPass.
INITIALIZE_PASS_BEGIN(ARM64PromoteConstant, "arm64-promote-const",
"ARM64 Promote Constant Pass", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ARM64PromoteConstant, "arm64-promote-const",
"ARM64 Promote Constant Pass", false, false)
/// Allocate a fresh instance of the ARM64 constant promotion pass.
/// \returns the newly created pass as a ModulePass pointer.
ModulePass *llvm::createARM64PromoteConstantPass() {
  ModulePass *PromoteConstPass = new ARM64PromoteConstant();
  return PromoteConstPass;
}
/// Tell whether \p CstTy is a vector type or an aggregate that contains one.
///
/// Arrays are inspected through their element type and structures through
/// each of their fields, recursively.
/// \returns true if a vector type is found, false otherwise.
static bool isConstantUsingVectorTy(const Type *CstTy) {
  if (CstTy->isVectorTy())
    return true;

  if (CstTy->isArrayTy())
    return isConstantUsingVectorTy(CstTy->getArrayElementType());

  if (!CstTy->isStructTy())
    return false;

  // A structure qualifies as soon as one of its fields does.
  const unsigned NumFields = CstTy->getStructNumElements();
  for (unsigned Field = 0; Field < NumFields; ++Field)
    if (isConstantUsingVectorTy(CstTy->getStructElementType(Field)))
      return true;
  return false;
}
static bool shouldConvertUse(const Constant *Cst, const Instruction *Instr,
unsigned OpIdx) {
if (isa<const ShuffleVectorInst>(Instr) && OpIdx == 2)
return false;
if (isa<const ExtractValueInst>(Instr) && OpIdx > 0)
return false;
if (isa<const InsertValueInst>(Instr) && OpIdx > 1)
return false;
if (isa<const AllocaInst>(Instr) && OpIdx > 0)
return false;
if (isa<const LoadInst>(Instr) && OpIdx > 0)
return false;
if (isa<const StoreInst>(Instr) && OpIdx > 1)
return false;
if (isa<const GetElementPtrInst>(Instr) && OpIdx > 0)
return false;
if (isa<const LandingPadInst>(Instr))
return false;
if (isa<const SwitchInst>(Instr))
return false;
if (isa<const IndirectBrInst>(Instr))
return false;
if (isa<const IntrinsicInst>(Instr))
return false;
const CallInst *CI = dyn_cast<const CallInst>(Instr);
if (CI && isa<const InlineAsm>(CI->getCalledValue()))
return false;
return true;
}
static bool shouldConvert(const Constant *Cst) {
if (isa<const UndefValue>(Cst))
return false;
if (Cst->isZeroValue())
return false;
if (Stress)
return true;
if (Cst->getType()->isVectorTy())
return false;
return isConstantUsingVectorTy(Cst->getType());
}
/// Find the instruction in front of which a definition reaching \p Use can be
/// inserted.
///
/// When the user is a PHI node this is the terminator of the incoming block
/// matching the used operand; otherwise it is the user itself.
/// \returns the insertion point (asserted to be non-null).
Instruction *
ARM64PromoteConstant::findInsertionPoint(Value::use_iterator &Use) {
  Instruction *InsertionPoint = 0;
  if (PHINode *Phi = dyn_cast<PHINode>(*Use)) {
    BasicBlock *IncomingBB = Phi->getIncomingBlock(Use.getOperandNo());
    InsertionPoint = IncomingBB->getTerminator();
  } else {
    InsertionPoint = dyn_cast<Instruction>(*Use);
  }
  assert(InsertionPoint && "User is not an instruction!");
  return InsertionPoint;
}
/// Check whether an insertion point already recorded in \p InsertPts covers
/// \p NewPt.
///
/// A recorded point covers \p NewPt when it is \p NewPt itself, when it
/// dominates \p NewPt, or when it lives in a different block that dominates
/// the block of \p NewPt. On the first match \p UseIt is appended to the uses
/// of that point.
/// \returns true if \p UseIt has been attached to an existing point.
bool ARM64PromoteConstant::isDominated(Instruction *NewPt,
                                       Value::use_iterator &UseIt,
                                       InsertionPoints &InsertPts) {
  BasicBlock *NewBB = NewPt->getParent();
  DominatorTree &DT =
      getAnalysis<DominatorTreeWrapperPass>(*NewBB->getParent()).getDomTree();

  for (InsertionPoints::iterator It = InsertPts.begin(),
                                 ItEnd = InsertPts.end();
       It != ItEnd; ++It) {
    Instruction *ExistingPt = It->first;

    bool Covers = NewPt == ExistingPt || DT.dominates(ExistingPt, NewPt);
    if (!Covers) {
      // Fall back on a block-level query when the two points are in distinct
      // blocks.
      BasicBlock *ExistingBB = ExistingPt->getParent();
      Covers = ExistingBB != NewBB && DT.dominates(ExistingBB, NewBB);
    }
    if (!Covers)
      continue;

    DEBUG(dbgs() << "Insertion point dominated by:\n");
    DEBUG(ExistingPt->print(dbgs()));
    DEBUG(dbgs() << '\n');
    It->second.push_back(UseIt);
    return true;
  }
  return false;
}
/// Try to fuse \p NewPt with one of the insertion points of \p InsertPts.
///
/// A recorded point located in the block of \p NewPt is replaced by \p NewPt.
/// Otherwise, if both blocks have a nearest common dominator, the recorded
/// point is replaced either by \p NewPt (when its block is that dominator) or
/// by the terminator of the common dominator. In every successful case
/// \p UseIt joins the uses of the replaced point.
/// \returns true if a merge happened, false if no recorded point qualified.
bool ARM64PromoteConstant::tryAndMerge(Instruction *NewPt,
                                       Value::use_iterator &UseIt,
                                       InsertionPoints &InsertPts) {
  BasicBlock *NewBB = NewPt->getParent();
  DominatorTree &DT =
      getAnalysis<DominatorTreeWrapperPass>(*NewBB->getParent()).getDomTree();

  for (InsertionPoints::iterator It = InsertPts.begin(),
                                 ItEnd = InsertPts.end();
       It != ItEnd; ++It) {
    BasicBlock *CurBB = It->first->getParent();
    Instruction *MergedPt = NewPt;

    if (NewBB != CurBB) {
      BasicBlock *CommonDom = DT.findNearestCommonDominator(NewBB, CurBB);
      // No common dominator: this recorded point cannot be merged with.
      if (!CommonDom)
        continue;

      if (CommonDom != NewBB) {
        assert(CommonDom != CurBB &&
               "Instruction has not been rejected during isDominated check!");
        // Hoist the definition to the end of the common dominator.
        MergedPt = CommonDom->getTerminator();
      }

      DEBUG(dbgs() << "Merge insertion point with:\n");
      DEBUG(It->first->print(dbgs()));
      DEBUG(dbgs() << '\n');
      DEBUG(MergedPt->print(dbgs()));
      DEBUG(dbgs() << '\n');
    } else {
      // Same block: the considered point takes over the recorded one.
      DEBUG(dbgs() << "Merge insertion point with:\n");
      DEBUG(It->first->print(dbgs()));
      DEBUG(dbgs() << "\nat considered insertion point.\n");
    }

    // The map is modified here, so leave the loop right away.
    appendAndTransferDominatedUses(MergedPt, UseIt, It, InsertPts);
    return true;
  }
  return false;
}
/// Compute, function by function, the insertion points covering the uses of
/// \p Val that may be rewritten.
///
/// Uses that are not instructions, or that shouldConvertUse() rejects, are
/// skipped. Each remaining use is attached to an existing insertion point when
/// one covers it, merged with an existing point when possible, or given an
/// insertion point of its own.
/// \param InsPtsPerFunc receives the insertion points and their uses.
void ARM64PromoteConstant::
computeInsertionPoints(Constant *Val,
                       InsertionPointsPerFunc &InsPtsPerFunc) {
  DEBUG(dbgs() << "** Compute insertion points **\n");
  for (Value::use_iterator UseIt = Val->use_begin(), UseEnd = Val->use_end();
       UseIt != UseEnd; ++UseIt) {
    Instruction *UserInst = dyn_cast<Instruction>(*UseIt);
    if (!UserInst)
      continue;

    const unsigned OpNo = UseIt.getOperandNo();
    if (!shouldConvertUse(Val, UserInst, OpNo))
      continue;

    DEBUG(dbgs() << "Considered use, opidx " << OpNo << ":\n");
    DEBUG(UserInst->print(dbgs()));
    DEBUG(dbgs() << '\n');

    Instruction *InsertionPoint = findInsertionPoint(UseIt);

    DEBUG(dbgs() << "Considered insertion point:\n");
    DEBUG(InsertionPoint->print(dbgs()));
    DEBUG(dbgs() << '\n');

    Function *Fn = InsertionPoint->getParent()->getParent();
    InsertionPoints &InsertPts = InsPtsPerFunc[Fn];

    // Reuse or merge with a recorded point whenever possible.
    if (isDominated(InsertionPoint, UseIt, InsertPts) ||
        tryAndMerge(InsertionPoint, UseIt, InsertPts))
      continue;

    DEBUG(dbgs() << "Keep considered insertion point\n");
    InsertPts[InsertionPoint].push_back(UseIt);
  }
}
/// Create the definitions described by \p InsPtsPerFunc and rewrite the uses.
///
/// One constant internal global named "_PromotedConst" and initialized with
/// \p Cst is created per module on first need. A load of that global is then
/// inserted in front of every insertion point, and each use attached to the
/// point is redirected to the load.
/// \returns true if a global has been created.
bool ARM64PromoteConstant::
insertDefinitions(Constant *Cst,
                  InsertionPointsPerFunc &InsPtsPerFunc) {
  DenseMap<Module *, GlobalVariable *> GlobalPerModule;
  bool MadeChange = false;

  for (InsertionPointsPerFunc::iterator FuncIt = InsPtsPerFunc.begin(),
                                        FuncEnd = InsPtsPerFunc.end();
       FuncIt != FuncEnd; ++FuncIt) {
    Function *Fn = FuncIt->first;
    InsertionPoints &InsertPts = FuncIt->second;
#ifndef NDEBUG
    // Only needed by the dominance assertion further down.
    DominatorTree &DT =
        getAnalysis<DominatorTreeWrapperPass>(*Fn).getDomTree();
#endif
    assert(!InsertPts.empty() && "Empty uses does not need a definition");

    // Fetch the global of this module, creating it on first use. A missing
    // entry is default-initialized to a null pointer by the map.
    Module *M = Fn->getParent();
    GlobalVariable *&PromotedGV = GlobalPerModule[M];
    if (!PromotedGV) {
      PromotedGV = new GlobalVariable(*M, Cst->getType(), true,
                                      GlobalValue::InternalLinkage,
                                      0, "_PromotedConst", 0,
                                      GlobalVariable::NotThreadLocal);
      PromotedGV->setInitializer(Cst);
      DEBUG(dbgs() << "Global replacement: ");
      DEBUG(PromotedGV->print(dbgs()));
      DEBUG(dbgs() << '\n');
      ++NumPromoted;
      MadeChange = true;
    }

    for (InsertionPoints::iterator PtIt = InsertPts.begin(),
                                   PtEnd = InsertPts.end();
         PtIt != PtEnd; ++PtIt) {
      Instruction *InsertBefore = PtIt->first;
      IRBuilder<> Builder(InsertBefore->getParent(), InsertBefore);
      LoadInst *LoadedCst = Builder.CreateLoad(PromotedGV);
      DEBUG(dbgs() << "**********\n");
      DEBUG(dbgs() << "New def: ");
      DEBUG(LoadedCst->print(dbgs()));
      DEBUG(dbgs() << '\n');

      Users &DominatedUsers = PtIt->second;
      for (Users::iterator UserIt = DominatedUsers.begin(),
                           UserEnd = DominatedUsers.end();
           UserIt != UserEnd; ++UserIt) {
        Value::use_iterator &Use = *UserIt;
#ifndef NDEBUG
        assert((DT.dominates(LoadedCst, cast<Instruction>(*Use)) ||
                (isa<PHINode>(*Use) &&
                 DT.dominates(LoadedCst, findInsertionPoint(Use)))) &&
               "Inserted definition does not dominate all its uses!");
#endif
        // Read the operand index before the operand is replaced.
        const unsigned OpNo = Use.getOperandNo();
        DEBUG(dbgs() << "Use to update " << OpNo << ":");
        DEBUG(Use->print(dbgs()));
        DEBUG(dbgs() << '\n');
        Use->setOperand(OpNo, LoadedCst);
        ++NumPromotedUses;
      }
    }
  }
  return MadeChange;
}
/// Compute the insertion points for the uses of \p Val, then insert the
/// matching definitions.
/// \returns the result of insertDefinitions().
bool ARM64PromoteConstant::computeAndInsertDefinitions(Constant *Val) {
  InsertionPointsPerFunc PointsPerFunction;
  computeInsertionPoints(Val, PointsPerFunction);
  const bool Inserted = insertDefinitions(Val, PointsPerFunction);
  return Inserted;
}
/// Promote \p Cst if shouldConvert() accepts it.
/// \param Cst the candidate constant; must not be null.
/// \returns true if definitions were inserted for \p Cst, false when the
/// constant is rejected or nothing was inserted.
bool ARM64PromoteConstant::promoteConstant(Constant *Cst) {
  assert(Cst && "Given variable is not a valid constant.");

  if (shouldConvert(Cst)) {
    DEBUG(dbgs() << "******************************\n");
    DEBUG(dbgs() << "Candidate constant: ");
    DEBUG(Cst->print(dbgs()));
    DEBUG(dbgs() << '\n');
    return computeAndInsertDefinitions(Cst);
  }
  return false;
}
/// Walk every operand of every instruction of \p F and try to promote the
/// constants found there.
///
/// Global values and constant expressions are ignored, and each remaining
/// constant is handed to promoteConstant() at most once per call.
/// \returns true if at least one promotion reported a change.
bool ARM64PromoteConstant::runOnFunction(Function &F) {
  SmallSet<Constant *, 8> Visited;
  bool Modified = false;

  for (Function::iterator BB = F.begin(), BBEnd = F.end(); BB != BBEnd; ++BB) {
    for (BasicBlock::iterator Inst = BB->begin(), InstEnd = BB->end();
         Inst != InstEnd; ++Inst) {
      const unsigned NumOps = Inst->getNumOperands();
      for (unsigned Idx = 0; Idx != NumOps; ++Idx) {
        Constant *Candidate = dyn_cast<Constant>(Inst->getOperand(Idx));
        if (!Candidate)
          continue;
        if (isa<GlobalValue>(Candidate) || isa<ConstantExpr>(Candidate))
          continue;
        // insert() reports false for a constant that was seen before.
        if (!Visited.insert(Candidate))
          continue;
        if (promoteConstant(Candidate))
          Modified = true;
      }
    }
  }
  return Modified;
}