//===-- ARM64AddressTypePromotion.cpp -------------------------------------===//
#define DEBUG_TYPE "arm64-type-promotion"
#include "ARM64.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
using namespace llvm;
// Master switch: when false, runOnFunction returns without touching the IR.
static cl::opt<bool> EnableAddressTypePromotion(
    "arm64-type-promotion",
    cl::desc("Enable the type promotion pass"),
    cl::init(true), cl::Hidden);

// When true, propagateSignExtension calls mergeSExts to fold sign extensions
// of the same value where one dominates the other.
static cl::opt<bool> EnableMerge(
    "arm64-type-promotion-merge",
    cl::desc("Enable merging of redundant sexts when one is dominating"
             " the other."),
    cl::init(true), cl::Hidden);
// Forward declaration of the pass-registry initializer so the pass
// constructor can call it. Its definition is expected to come from the
// INITIALIZE_PASS_BEGIN/END macros used in this file.
namespace llvm {
void initializeARM64AddressTypePromotionPass(PassRegistry&);
}
namespace {
/// Function pass that moves sign extensions of type ConsideredSExtType
/// (i64, set per function in runOnFunction) which feed GEPs up through their
/// operand chains, and merges redundant ones.
class ARM64AddressTypePromotion : public FunctionPass {
public:
  static char ID;

  ARM64AddressTypePromotion()
      : FunctionPass(ID), Func(NULL), ConsideredSExtType(NULL) {
    initializeARM64AddressTypePromotionPass(*PassRegistry::getPassRegistry());
  }

  virtual const char *getPassName() const {
    return "ARM64 Address Type Promotion";
  }

  bool runOnFunction(Function &F);

private:
  // Container aliases used by the helpers below (declared first so the
  // member function declarations can name them).
  typedef SmallPtrSet<Instruction *, 32> SetOfInstructions;
  typedef SmallVector<Instruction *, 16> Instructions;
  typedef DenseMap<Value *, Instructions> ValueToInsts;

  /// The function currently being transformed.
  Function *Func;
  /// Destination type of the sign extensions this pass considers.
  Type *ConsideredSExtType;

  /// Requires the dominator tree (used when merging sexts) and declares that
  /// the CFG and dominator tree are preserved.
  virtual void getAnalysisUsage(AnalysisUsage &AU) const {
    AU.setPreservesCFG();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  // Chain-walking predicates.
  bool shouldGetThrough(const Instruction *Inst);
  bool canGetThrough(const Instruction *Inst);
  bool shouldConsiderSExt(const Instruction *SExt) const;

  // Analysis and transformation phases.
  void analyzeSExtension(Instructions &SExtInsts);
  bool propagateSignExtension(Instructions &SExtInsts);
  void mergeSExts(ValueToInsts &ValToSExtendedUses,
                  SetOfInstructions &ToRemove);
};
} // end anonymous namespace
// Pass identifier; the legacy pass manager uses its address, not its value.
char ARM64AddressTypePromotion::ID = 0;
// Register the pass under "arm64-type-promotion" and record its dependency on
// the dominator tree analysis.
INITIALIZE_PASS_BEGIN(ARM64AddressTypePromotion, "arm64-type-promotion",
"ARM64 Type Promotion Pass", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ARM64AddressTypePromotion, "arm64-type-promotion",
"ARM64 Type Promotion Pass", false, false)
/// Factory used by the target to add this pass to a pipeline. The caller
/// (normally a pass manager) takes ownership of the returned object.
FunctionPass *llvm::createARM64AddressTypePromotionPass() {
  ARM64AddressTypePromotion *Pass = new ARM64AddressTypePromotion();
  return Pass;
}
/// Legality check: return true if a sign extension whose operand is \p Inst
/// may be moved above \p Inst without changing the computed value.
///
/// Accepted forms:
///  - sext(sext x)                     --> sext x
///  - sext(binop nsw a, b)             --> binop(sext a, sext b)
///  - sext(trunc(sext x)), when the trunc only drops bits created by the
///    inner sext                       --> sext x
bool ARM64AddressTypePromotion::canGetThrough(const Instruction *Inst) {
  if (isa<SExtInst>(Inst))
    return true;

  // Only the no-signed-wrap flag makes sign extension distribute over an
  // overflowing binary operator. No-unsigned-wrap alone is NOT enough:
  // in i8, "add nuw 127, 1" yields 0x80, which sign extends to -128, whereas
  // adding the sign-extended operands yields +128.
  const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst);
  if (BinOp && isa<OverflowingBinaryOperator>(BinOp) &&
      BinOp->hasNoSignedWrap())
    return true;

  // sext(trunc(sext x)) --> sext x
  if (isa<TruncInst>(Inst) && isa<SExtInst>(Inst->getOperand(0))) {
    const Instruction *Opnd = cast<Instruction>(Inst->getOperand(0));
    // The trunc must keep at least all bits of the original value x (so it
    // only drops sign-extended bits), and the inner sext must not be wider
    // than the type this pass promotes to.
    if (Inst->getType()->getIntegerBitWidth() >=
            Opnd->getOperand(0)->getType()->getIntegerBitWidth() &&
        Inst->getOperand(0)->getType()->getIntegerBitWidth() <=
            ConsideredSExtType->getIntegerBitWidth())
      return true;
  }

  return false;
}
/// Profitability heuristic: return true if it is worth moving the sign
/// extension above \p Inst (legality is checked by canGetThrough).
bool ARM64AddressTypePromotion::shouldGetThrough(const Instruction *Inst) {
  const bool SingleUse = Inst->hasOneUse();

  // A sext is absorbed when nothing else uses it, or when it already
  // produces the promoted type.
  if (isa<SExtInst>(Inst) &&
      (SingleUse || Inst->getType() == ConsideredSExtType))
    return true;

  // Anything else is rewritten in place, so the sext being moved must be its
  // only user.
  if (!SingleUse)
    return false;

  // Truncs are absorbed; binary operators only when the second operand is a
  // constant, so at most one operand needs a fresh sext.
  return isa<TruncInst>(Inst) ||
         (isa<BinaryOperator>(Inst) && isa<ConstantInt>(Inst->getOperand(1)));
}
/// Return false for operands that must keep their original type when the
/// sign extension is pushed onto \p Inst's operands. Currently this is only
/// operand 0 of a select, which is its condition.
static bool shouldSExtOperand(const Instruction *Inst, int OpIdx) {
  const bool IsSelectCondition = isa<SelectInst>(Inst) && OpIdx == 0;
  return !IsSelectCondition;
}
bool
ARM64AddressTypePromotion::shouldConsiderSExt(const Instruction *SExt) const {
if (SExt->getType() != ConsideredSExtType)
return false;
for (Value::const_use_iterator UseIt = SExt->use_begin(),
EndUseIt = SExt->use_end(); UseIt != EndUseIt; ++UseIt) {
if (isa<GetElementPtrInst>(*UseIt))
return true;
}
return false;
}
/// Move every sign extension in \p SExtInsts (the worklist built by
/// analyzeSExtension) up its operand chain as far as canGetThrough and
/// shouldGetThrough allow. Intermediate binary operators are retyped to the
/// sext's type in place. Redundant sexts are optionally merged (EnableMerge),
/// and all instructions made dead are erased.
///
/// \returns true if the IR was modified.
bool ARM64AddressTypePromotion::
propagateSignExtension(Instructions &SExtInsts) {
  DEBUG(dbgs() << "*** Propagate Sign Extension ***\n");

  bool LocalChange = false;
  // Instructions made dead by the rewrite. They are only erased at the end,
  // because worklist entries are still looked up in this set below.
  SetOfInstructions ToRemove;
  // Surviving sexts grouped by the value they extend; input to mergeSExts.
  ValueToInsts ValToSExtendedUses;

  while (!SExtInsts.empty()) {
    Instruction *SExt = SExtInsts.pop_back_val();
    DEBUG(dbgs() << "Consider:\n" << *SExt << '\n');

    // Already absorbed while walking another chain.
    if (SExt->use_empty() && ToRemove.count(SExt)) {
      DEBUG(dbgs() << "No uses => marked as delete\n");
      continue;
    }

    // Walk the sext up through its operand chain.
    while (isa<Instruction>(SExt->getOperand(0))) {
      Instruction *Inst = cast<Instruction>(SExt->getOperand(0));
      DEBUG(dbgs() << "Try to get through:\n" << *Inst << '\n');
      if (!canGetThrough(Inst) || !shouldGetThrough(Inst)) {
        DEBUG(dbgs() << "Cannot get through\n");
        break;
      }

      LocalChange = true;

      // sext(sext x) / sext(trunc(...)): redirect Inst's users to SExt, make
      // SExt read Inst's operand directly, and drop Inst.
      if (isa<SExtInst>(Inst) || isa<TruncInst>(Inst)) {
        DEBUG(dbgs() << "SExt or trunc, mark it as to remove\n");
        while (!Inst->use_empty()) {
          Value::use_iterator UseIt = Inst->use_begin();
          Instruction *UseInst = dyn_cast<Instruction>(*UseIt);
          assert(UseInst && "Use of sext is not an Instruction!");
          UseInst->setOperand(UseIt.getOperandNo(), SExt);
        }
        ToRemove.insert(Inst);
        SExt->setOperand(0, Inst->getOperand(0));
        // Keep SExt above all of Inst's former users.
        SExt->moveBefore(Inst);
        continue;
      }

      // Binary operator: compute it directly in the wide type and let it
      // take over SExt's users.
      Inst->mutateType(SExt->getType());
      SExt->replaceAllUsesWith(Inst);

      // SExt itself is recycled to extend the (single) non-constant operand.
      Instruction *SExtForOpnd = SExt;
      DEBUG(dbgs() << "Propagate SExt to operands\n");
      for (int OpIdx = 0, EndOpIdx = Inst->getNumOperands(); OpIdx != EndOpIdx;
           ++OpIdx) {
        DEBUG(dbgs() << "Operand:\n" << *(Inst->getOperand(OpIdx)) << '\n');
        if (Inst->getOperand(OpIdx)->getType() == SExt->getType() ||
            !shouldSExtOperand(Inst, OpIdx)) {
          DEBUG(dbgs() << "No need to propagate\n");
          continue;
        }
        Value *Opnd = Inst->getOperand(OpIdx);
        // Constants and undef are widened statically; no instruction needed.
        if (const ConstantInt *Cst = dyn_cast<ConstantInt>(Opnd)) {
          DEBUG(dbgs() << "Statically sign extend\n");
          Inst->setOperand(OpIdx, ConstantInt::getSigned(SExt->getType(),
                                                         Cst->getSExtValue()));
          continue;
        }
        if (isa<UndefValue>(Opnd)) {
          DEBUG(dbgs() << "Statically sign extend\n");
          Inst->setOperand(OpIdx, UndefValue::get(SExt->getType()));
          continue;
        }
        // shouldGetThrough only admits binary operators whose second operand
        // is a constant, so at most one operand reaches this point.
        assert (SExtForOpnd &&
                "Only one operand should have been sign extended");
        SExtForOpnd->setOperand(0, Opnd);
        DEBUG(dbgs() << "Move before:\n" << *Inst <<
              "\nSign extend\n");
        SExtForOpnd->moveBefore(Inst);
        Inst->setOperand(OpIdx, SExtForOpnd);
        SExtForOpnd = NULL;
      }
      // Every operand was widened statically: the sext has no job left.
      if (SExtForOpnd == SExt) {
        DEBUG(dbgs() << "Sign extension is useless now\n");
        ToRemove.insert(SExt);
        break;
      }
    }

    // A sext whose operand already has the wide type is a no-op.
    if (!ToRemove.count(SExt) &&
        SExt->getType() == SExt->getOperand(0)->getType()) {
      DEBUG(dbgs() << "Sign extension is useless, attach its use to "
            "its argument\n");
      SExt->replaceAllUsesWith(SExt->getOperand(0));
      ToRemove.insert(SExt);
    } else
      ValToSExtendedUses[SExt->getOperand(0)].push_back(SExt);
  }

  if (EnableMerge)
    mergeSExts(ValToSExtendedUses, ToRemove);

  // mergeSExts can rewrite and delete sexts even when no chain was walked
  // (e.g. two sexts of the same function argument, one dominating the
  // other). Anything scheduled for deletion means the function changed, and
  // the pass must report it.
  if (!ToRemove.empty())
    LocalChange = true;

  for (SetOfInstructions::iterator ToRemoveIt = ToRemove.begin(),
       EndToRemoveIt = ToRemove.end(); ToRemoveIt != EndToRemoveIt;
       ++ToRemoveIt)
    (*ToRemoveIt)->eraseFromParent();
  return LocalChange;
}
/// For each extended value, fold together sign extensions where one
/// dominates the other. The dominated sext's uses are redirected to the
/// dominating one, and the dominated sext is added to \p ToRemove (the caller
/// erases it).
void ARM64AddressTypePromotion::mergeSExts(ValueToInsts &ValToSExtendedUses,
                                           SetOfInstructions &ToRemove) {
  DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();

  for (ValueToInsts::iterator Entry = ValToSExtendedUses.begin(),
                              EntryEnd = ValToSExtendedUses.end();
       Entry != EntryEnd; ++Entry) {
    Instructions &Candidates = Entry->second;
    // Sexts kept so far for this value; a new candidate is compared against
    // them and either replaces one, is folded into one, or is appended.
    Instructions Survivors;

    for (unsigned CandIdx = 0, CandEnd = Candidates.size(); CandIdx != CandEnd;
         ++CandIdx) {
      Instruction *Cand = Candidates[CandIdx];
      if (ToRemove.count(Cand))
        continue;

      bool Merged = false;
      for (unsigned SurvIdx = 0, SurvEnd = Survivors.size();
           !Merged && SurvIdx != SurvEnd; ++SurvIdx) {
        Instruction *&Surv = Survivors[SurvIdx];
        if (DT.dominates(Cand, Surv)) {
          // The candidate is higher up: it takes the survivor's place.
          DEBUG(dbgs() << "Replace all uses of:\n" << *Surv <<
                "\nwith:\n" << *Cand << '\n');
          Surv->replaceAllUsesWith(Cand);
          ToRemove.insert(Surv);
          Surv = Cand;
          Merged = true;
        } else if (DT.dominates(Surv, Cand)) {
          // The survivor already covers the candidate.
          DEBUG(dbgs() << "Replace all uses of:\n" << *Cand <<
                "\nwith:\n" << *Surv << '\n');
          Cand->replaceAllUsesWith(Surv);
          ToRemove.insert(Cand);
          Merged = true;
        }
      }

      if (!Merged)
        Survivors.push_back(Cand);
    }
  }
}
/// Scan the current function for candidate sign extensions and fill
/// \p SExtInsts with those worth promoting. A candidate is queued if it feeds
/// a GEP with more than two operands, or if another candidate shares the same
/// chain head (in which case both are queued).
void ARM64AddressTypePromotion::analyzeSExtension(Instructions &SExtInsts) {
  DEBUG(dbgs() << "*** Analyze Sign Extensions ***\n");
  // Chain head -> first candidate seen with that head that was not queued
  // yet. NULL once that candidate has been queued.
  DenseMap<Value *, Instruction *> SeenChains;

  for (Function::iterator BB = Func->begin(), BBEnd = Func->end();
       BB != BBEnd; ++BB) {
    for (BasicBlock::iterator I = BB->begin(), IEnd = BB->end(); I != IEnd;
         ++I) {
      if (!isa<SExtInst>(I) || !shouldConsiderSExt(I))
        continue;
      Instruction *SExt = I;
      DEBUG(dbgs() << "Found:\n" << (*I) <<'\n');

      // Directly interesting when used by a GEP with more than two operands.
      bool Interesting = false;
      for (Value::use_iterator UseIt = SExt->use_begin(),
                               UseEnd = SExt->use_end();
           !Interesting && UseIt != UseEnd; ++UseIt) {
        const Instruction *GEP = dyn_cast<GetElementPtrInst>(*UseIt);
        if (GEP && GEP->getNumOperands() > 2) {
          DEBUG(dbgs() << "Interesting use in GetElementPtrInst\n" <<
                *GEP << '\n');
          Interesting = true;
        }
      }

      // Follow the operand chain as far as the sext could be moved; the
      // value where the walk stops is the head of the chain.
      Value *Head = NULL;
      Instruction *Cur = SExt;
      for (;;) {
        const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Cur);
        int OpIdx = (BinOp && isa<ConstantInt>(BinOp->getOperand(0))) ? 1 : 0;
        Head = Cur->getOperand(OpIdx);
        Cur = dyn_cast<Instruction>(Head);
        if (!Cur || !canGetThrough(Cur) || !shouldGetThrough(Cur))
          break;
      }
      DEBUG(dbgs() << "Head of the chain:\n" << *Head << '\n');

      DenseMap<Value *, Instruction *>::iterator Seen = SeenChains.find(Head);
      const bool AlreadySeen = Seen != SeenChains.end();

      if (!Interesting && !AlreadySeen) {
        // Not queued yet; remember it in case a later sext shares the head.
        DEBUG(dbgs() << "Record its chain membership\n");
        SeenChains[Head] = SExt;
        continue;
      }

      DEBUG(dbgs() << "Insert\n");
      SExtInsts.push_back(SExt);
      if (AlreadySeen && Seen->second != NULL) {
        // Also queue the earlier member of this chain, exactly once.
        DEBUG(dbgs() << "Insert chain member\n");
        SExtInsts.push_back(Seen->second);
        Seen->second = NULL;
      }
    }
  }
}
/// Pass entry point: collect candidate i64 sign extensions in \p F, then
/// promote them. Returns whatever propagateSignExtension reports.
bool ARM64AddressTypePromotion::runOnFunction(Function &F) {
  // Skip when disabled on the command line or when there is no body.
  const bool ShouldRun = EnableAddressTypePromotion && !F.isDeclaration();
  if (!ShouldRun)
    return false;

  Func = &F;
  ConsideredSExtType = Type::getInt64Ty(Func->getContext());
  DEBUG(dbgs() << "*** " << getPassName() << ": " << Func->getName() << '\n' );

  Instructions Worklist;
  analyzeSExtension(Worklist);
  return propagateSignExtension(Worklist);
}