#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/RecyclingAllocator.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include <deque>
using namespace llvm;
using namespace llvm::PatternMatch;
// Component name used by the DEBUG() and STATISTIC() macros in this file.
#define DEBUG_TYPE "early-cse"
// Counters bumped by EarlyCSE::processNode each time the corresponding
// rewrite fires.
// Trivially dead instructions erased, or instructions folded by
// SimplifyInstruction.
STATISTIC(NumSimplify, "Number of instructions simplified or DCE'd");
// Pure (memory-free) instructions replaced by an equivalent available one.
STATISTIC(NumCSE, "Number of instructions CSE'd");
// Loads replaced by a value from an earlier load/store of the same pointer.
STATISTIC(NumCSELoad, "Number of load instructions CSE'd");
// Read-only calls replaced by an earlier identical call.
STATISTIC(NumCSECall, "Number of call instructions CSE'd");
// Stores deleted because a later store overwrote the same location unread.
STATISTIC(NumDSE, "Number of trivial dead stores removed");
namespace {
struct SimpleValue {
Instruction *Inst;
SimpleValue(Instruction *I) : Inst(I) {
assert((isSentinel() || canHandle(I)) && "Inst can't be handled!");
}
bool isSentinel() const {
return Inst == DenseMapInfo<Instruction *>::getEmptyKey() ||
Inst == DenseMapInfo<Instruction *>::getTombstoneKey();
}
static bool canHandle(Instruction *Inst) {
if (CallInst *CI = dyn_cast<CallInst>(Inst))
return CI->doesNotAccessMemory() && !CI->getType()->isVoidTy();
return isa<CastInst>(Inst) || isa<BinaryOperator>(Inst) ||
isa<GetElementPtrInst>(Inst) || isa<CmpInst>(Inst) ||
isa<SelectInst>(Inst) || isa<ExtractElementInst>(Inst) ||
isa<InsertElementInst>(Inst) || isa<ShuffleVectorInst>(Inst) ||
isa<ExtractValueInst>(Inst) || isa<InsertValueInst>(Inst);
}
};
}
namespace llvm {
/// Hash-table traits for SimpleValue keys. The marker keys reuse the
/// Instruction* markers; hashing and equality are defined out of line.
template <> struct DenseMapInfo<SimpleValue> {
  static SimpleValue getEmptyKey() {
    return SimpleValue(DenseMapInfo<Instruction *>::getEmptyKey());
  }
  static SimpleValue getTombstoneKey() {
    return SimpleValue(DenseMapInfo<Instruction *>::getTombstoneKey());
  }
  static unsigned getHashValue(SimpleValue Val);
  static bool isEqual(SimpleValue LHS, SimpleValue RHS);
};
}
/// Hash a SimpleValue. Any two keys that DenseMapInfo<SimpleValue>::isEqual
/// treats as equal must hash identically, so commutable forms (commutative
/// binary operators, and compares with swapped operands/predicate) are put
/// into a canonical form before hashing.
unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) {
  Instruction *Inst = Val.Inst;
  // Hash in all of the operands as pointers.
  if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst)) {
    Value *LHS = BinOp->getOperand(0);
    Value *RHS = BinOp->getOperand(1);
    // Order commutative operands by address so "a op b" and "b op a" agree.
    if (BinOp->isCommutative() && LHS > RHS)
      std::swap(LHS, RHS);

    if (isa<OverflowingBinaryOperator>(BinOp)) {
      // Hash the overflow behavior too.
      unsigned Overflow =
          BinOp->hasNoSignedWrap() * OverflowingBinaryOperator::NoSignedWrap |
          BinOp->hasNoUnsignedWrap() *
              OverflowingBinaryOperator::NoUnsignedWrap;
      return hash_combine(BinOp->getOpcode(), Overflow, LHS, RHS);
    }

    return hash_combine(BinOp->getOpcode(), LHS, RHS);
  }

  if (CmpInst *CI = dyn_cast<CmpInst>(Inst)) {
    // A compare can be commuted by swapping its operands and its predicate.
    // Choose the form whose operands are in address order. On a tie (both
    // operands are the same value), choose the smaller predicate. Without the
    // tie-break, "icmp sgt %a, %a" and "icmp slt %a, %a" would hash
    // differently even though isEqual considers them equal.
    Value *LHS = CI->getOperand(0);
    Value *RHS = CI->getOperand(1);
    CmpInst::Predicate Pred = CI->getPredicate();
    CmpInst::Predicate SwappedPred = CI->getSwappedPredicate();
    if (LHS > RHS || (LHS == RHS && Pred > SwappedPred)) {
      std::swap(LHS, RHS);
      Pred = SwappedPred;
    }
    return hash_combine(Inst->getOpcode(), Pred, LHS, RHS);
  }

  if (CastInst *CI = dyn_cast<CastInst>(Inst))
    return hash_combine(CI->getOpcode(), CI->getType(), CI->getOperand(0));

  if (const ExtractValueInst *EVI = dyn_cast<ExtractValueInst>(Inst))
    return hash_combine(EVI->getOpcode(), EVI->getOperand(0),
                        hash_combine_range(EVI->idx_begin(), EVI->idx_end()));

  if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(Inst))
    return hash_combine(IVI->getOpcode(), IVI->getOperand(0),
                        IVI->getOperand(1),
                        hash_combine_range(IVI->idx_begin(), IVI->idx_end()));

  assert((isa<CallInst>(Inst) || isa<BinaryOperator>(Inst) ||
          isa<GetElementPtrInst>(Inst) || isa<SelectInst>(Inst) ||
          isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) ||
          isa<ShuffleVectorInst>(Inst)) &&
         "Invalid/unknown instruction");

  // Everything else: opcode plus all operands, in order.
  return hash_combine(
      Inst->getOpcode(),
      hash_combine_range(Inst->value_op_begin(), Inst->value_op_end()));
}
/// Structural equality for SimpleValue keys. Two instructions are equal if
/// they are identical, or if one is the commuted form of the other (a
/// commutative binary operator with swapped operands, or a compare with
/// swapped operands and swapped predicate) carrying exactly the same flags.
bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) {
  Instruction *LHSI = LHS.Inst, *RHSI = RHS.Inst;

  // Marker keys only ever equal themselves.
  if (LHS.isSentinel() || RHS.isSentinel())
    return LHSI == RHSI;

  if (LHSI->getOpcode() != RHSI->getOpcode())
    return false;
  if (LHSI->isIdenticalTo(RHSI))
    return true;

  // From here on, only commuted forms can match. EarlyCSE substitutes one
  // instruction for the other wholesale and does not intersect their flags.
  // So, as isIdenticalTo already requires, the optional flags (nuw/nsw,
  // exact, fast-math) must match exactly. Checking only nuw/nsw would let a
  // "fadd fast %a, %b" replace a flag-free "fadd %b, %a".
  if (LHSI->getRawSubclassOptionalData() !=
      RHSI->getRawSubclassOptionalData())
    return false;

  if (BinaryOperator *LHSBinOp = dyn_cast<BinaryOperator>(LHSI)) {
    if (!LHSBinOp->isCommutative())
      return false;

    assert(isa<BinaryOperator>(RHSI) &&
           "same opcode, but different instruction type?");
    BinaryOperator *RHSBinOp = cast<BinaryOperator>(RHSI);

    // Commuted equality.
    return LHSBinOp->getOperand(0) == RHSBinOp->getOperand(1) &&
           LHSBinOp->getOperand(1) == RHSBinOp->getOperand(0);
  }

  if (CmpInst *LHSCmp = dyn_cast<CmpInst>(LHSI)) {
    assert(isa<CmpInst>(RHSI) &&
           "same opcode, but different instruction type?");
    CmpInst *RHSCmp = cast<CmpInst>(RHSI);
    // Commuted equality: swapped operands with the swapped predicate.
    return LHSCmp->getOperand(0) == RHSCmp->getOperand(1) &&
           LHSCmp->getOperand(1) == RHSCmp->getOperand(0) &&
           LHSCmp->getSwappedPredicate() == RHSCmp->getPredicate();
  }

  return false;
}
namespace {
struct CallValue {
Instruction *Inst;
CallValue(Instruction *I) : Inst(I) {
assert((isSentinel() || canHandle(I)) && "Inst can't be handled!");
}
bool isSentinel() const {
return Inst == DenseMapInfo<Instruction *>::getEmptyKey() ||
Inst == DenseMapInfo<Instruction *>::getTombstoneKey();
}
static bool canHandle(Instruction *Inst) {
if (Inst->getType()->isVoidTy())
return false;
CallInst *CI = dyn_cast<CallInst>(Inst);
if (!CI || !CI->onlyReadsMemory())
return false;
return true;
}
};
}
namespace llvm {
/// Hash-table traits for CallValue keys. Marker keys reuse the Instruction*
/// markers; hashing and equality are defined out of line.
template <> struct DenseMapInfo<CallValue> {
  static CallValue getEmptyKey() {
    return CallValue(DenseMapInfo<Instruction *>::getEmptyKey());
  }
  static CallValue getTombstoneKey() {
    return CallValue(DenseMapInfo<Instruction *>::getTombstoneKey());
  }
  static unsigned getHashValue(CallValue Val);
  static bool isEqual(CallValue LHS, CallValue RHS);
};
}
/// Hash a call key from its opcode and the pointer identity of every operand.
unsigned DenseMapInfo<CallValue>::getHashValue(CallValue Val) {
  Instruction *Call = Val.Inst;
  hash_code OperandsHash =
      hash_combine_range(Call->value_op_begin(), Call->value_op_end());
  return hash_combine(Call->getOpcode(), OperandsHash);
}
/// Two real calls are equal only if isIdenticalTo says so; marker keys are
/// compared by pointer identity.
bool DenseMapInfo<CallValue>::isEqual(CallValue LHS, CallValue RHS) {
  if (!LHS.isSentinel() && !RHS.isSentinel())
    return LHS.Inst->isIdenticalTo(RHS.Inst);
  return LHS.Inst == RHS.Inst;
}
namespace {
/// State for one EarlyCSE run over a function. The dominator tree is walked
/// depth-first, and scoped hash tables hold the values available in the
/// current block because they were computed in a dominating block.
class EarlyCSE {
public:
  Function &F;
  const TargetLibraryInfo &TLI;
  const TargetTransformInfo &TTI;
  DominatorTree &DT;
  AssumptionCache &AC;

  /// Pure instructions (see SimpleValue) available at the current point.
  typedef RecyclingAllocator<
      BumpPtrAllocator, ScopedHashTableVal<SimpleValue, Value *>> AllocatorTy;
  typedef ScopedHashTable<SimpleValue, Value *, DenseMapInfo<SimpleValue>,
                          AllocatorTy> ScopedHTType;
  ScopedHTType AvailableValues;

  /// Pointer -> (load, store or memory intrinsic that last accessed it,
  /// generation in which it was recorded).
  typedef RecyclingAllocator<
      BumpPtrAllocator,
      ScopedHashTableVal<Value *, std::pair<Value *, unsigned>>>
      LoadMapAllocator;
  typedef ScopedHashTable<Value *, std::pair<Value *, unsigned>,
                          DenseMapInfo<Value *>, LoadMapAllocator> LoadHTType;
  LoadHTType AvailableLoads;

  /// Read-only call -> (earlier equivalent call, generation in which it was
  /// recorded).
  typedef ScopedHashTable<CallValue, std::pair<Value *, unsigned>> CallHTType;
  CallHTType AvailableCalls;

  /// Memory "epoch" counter. Load and call entries are only reused while
  /// their recorded generation equals this value.
  unsigned CurrentGeneration;

  EarlyCSE(Function &F, const TargetLibraryInfo &TLI,
           const TargetTransformInfo &TTI, DominatorTree &DT,
           AssumptionCache &AC)
      : F(F), TLI(TLI), TTI(TTI), DT(DT), AC(AC), CurrentGeneration(0) {}

  bool run();

private:
  /// Opens one scope in each of the three tables for the lifetime of this
  /// object (ScopedHashTable scopes pop their entries when destroyed).
  class NodeScope {
  public:
    NodeScope(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads,
              CallHTType &AvailableCalls)
        : Scope(AvailableValues), LoadScope(AvailableLoads),
          CallScope(AvailableCalls) {}

  private:
    NodeScope(const NodeScope &) LLVM_DELETED_FUNCTION;
    void operator=(const NodeScope &) LLVM_DELETED_FUNCTION;

    ScopedHTType::ScopeTy Scope;
    LoadHTType::ScopeTy LoadScope;
    CallHTType::ScopeTy CallScope;
  };

  /// Work-list entry for the iterative dominator-tree walk in run(): a tree
  /// node, its child cursor, the generation it starts with and hands to its
  /// children, and the table scopes owned while it is live.
  class StackNode {
  public:
    StackNode(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads,
              CallHTType &AvailableCalls, unsigned cg, DomTreeNode *n,
              DomTreeNode::iterator child, DomTreeNode::iterator end)
        : CurrentGeneration(cg), ChildGeneration(cg), Node(n), ChildIter(child),
          EndIter(end), Scopes(AvailableValues, AvailableLoads, AvailableCalls),
          Processed(false) {}

    unsigned currentGeneration() { return CurrentGeneration; }
    unsigned childGeneration() { return ChildGeneration; }
    void childGeneration(unsigned generation) { ChildGeneration = generation; }
    DomTreeNode *node() { return Node; }
    DomTreeNode::iterator childIter() { return ChildIter; }
    DomTreeNode *nextChild() {
      DomTreeNode *child = *ChildIter;
      ++ChildIter;
      return child;
    }
    DomTreeNode::iterator end() { return EndIter; }
    bool isProcessed() { return Processed; }
    void process() { Processed = true; }

  private:
    StackNode(const StackNode &) LLVM_DELETED_FUNCTION;
    void operator=(const StackNode &) LLVM_DELETED_FUNCTION;

    unsigned CurrentGeneration;
    unsigned ChildGeneration;
    DomTreeNode *Node;
    DomTreeNode::iterator ChildIter;
    DomTreeNode::iterator EndIter;
    NodeScope Scopes;
    bool Processed;
  };

  /// Uniform view of a load, a store, or a target memory intrinsic that
  /// TTI::getTgtMemIntrinsic describes with exactly one memory reference.
  /// isValid() is false (no pointer) for anything else. mayRead/mayWrite
  /// default to the instruction's own answers and are overridden by TTI's
  /// description for recognized intrinsics.
  class ParseMemoryInst {
  public:
    ParseMemoryInst(Instruction *Inst, const TargetTransformInfo &TTI)
        : Load(false), Store(false), Vol(false), MayReadFromMemory(false),
          MayWriteToMemory(false), MatchingId(-1), Ptr(nullptr) {
      MayReadFromMemory = Inst->mayReadFromMemory();
      MayWriteToMemory = Inst->mayWriteToMemory();
      if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
        MemIntrinsicInfo Info;
        if (!TTI.getTgtMemIntrinsic(II, Info))
          return;
        if (Info.NumMemRefs == 1) {
          Store = Info.WriteMem;
          Load = Info.ReadMem;
          MatchingId = Info.MatchingId;
          MayReadFromMemory = Info.ReadMem;
          MayWriteToMemory = Info.WriteMem;
          Vol = Info.Vol;
          Ptr = Info.PtrVal;
        }
      } else if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
        Load = true;
        Vol = !LI->isSimple();
        Ptr = LI->getPointerOperand();
      } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
        Store = true;
        Vol = !SI->isSimple();
        Ptr = SI->getPointerOperand();
      }
    }
    bool isLoad() { return Load; }
    bool isStore() { return Store; }
    bool isVolatile() { return Vol; }
    bool isMatchingMemLoc(const ParseMemoryInst &Inst) {
      return Ptr == Inst.Ptr && MatchingId == Inst.MatchingId;
    }
    bool isValid() { return Ptr != nullptr; }
    int getMatchingId() { return MatchingId; }
    Value *getPtr() { return Ptr; }
    bool mayReadFromMemory() { return MayReadFromMemory; }
    bool mayWriteToMemory() { return MayWriteToMemory; }

  private:
    bool Load;
    bool Store;
    bool Vol;
    bool MayReadFromMemory;
    bool MayWriteToMemory;
    int MatchingId;
    Value *Ptr;
  };

  bool processNode(DomTreeNode *Node);

  /// Given \p Inst, a load, store or memory intrinsic previously recorded in
  /// AvailableLoads, return the value it makes available at its pointer as a
  /// value of type \p ExpectedType. Returns nullptr if no such value exists.
  Value *getOrCreateResult(Value *Inst, Type *ExpectedType) const {
    Value *Result = nullptr;
    if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
      Result = LI;
    } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
      Result = SI->getValueOperand();
    } else {
      assert(isa<IntrinsicInst>(Inst) && "Instruction not supported");
      return TTI.getOrCreateResultFromMemIntrinsic(cast<IntrinsicInst>(Inst),
                                                   ExpectedType);
    }
    // AvailableLoads is keyed only by pointer, so a plain load/store can
    // share its pointer with a target memory intrinsic that has a different
    // result type. The caller feeds the result to replaceAllUsesWith, which
    // requires identical types, so refuse to forward across a mismatch.
    return Result->getType() == ExpectedType ? Result : nullptr;
  }
};
}
/// Walk the instructions of Node's block in order. Against the entries in
/// AvailableValues / AvailableLoads / AvailableCalls, this:
///  - erases trivially dead instructions and folds those that
///    SimplifyInstruction can simplify;
///  - replaces pure instructions (SimpleValue), loads, and read-only calls
///    (CallValue) with an equivalent earlier value when one is available;
///  - deletes a store when a later store in this block writes the same
///    location with no intervening memory read.
/// New facts are recorded in the tables so that dominated blocks (processed
/// while this block's scopes are still open) can reuse them.
/// Returns true if any instruction was replaced or erased.
bool EarlyCSE::processNode(DomTreeNode *Node) {
BasicBlock *BB = Node->getBlock();
// Without a unique predecessor, the memory state reaching this block need
// not be the one left by its dominator. Start a new generation so loads and
// calls recorded earlier are not reused here.
if (!BB->getSinglePredecessor())
++CurrentGeneration;
// Most recent non-volatile store in this block that nothing has read since.
// It is the candidate for trivial dead-store elimination.
Instruction *LastStore = nullptr;
bool Changed = false;
const DataLayout &DL = BB->getModule()->getDataLayout();
// The iterator is advanced before Inst is processed because Inst may be erased.
for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) {
Instruction *Inst = I++;
// Dead code: remove outright.
if (isInstructionTriviallyDead(Inst, &TLI)) {
DEBUG(dbgs() << "EarlyCSE DCE: " << *Inst << '\n');
Inst->eraseFromParent();
Changed = true;
++NumSimplify;
continue;
}
// llvm.assume calls are skipped entirely, so they never reach the memory
// handling below (no generation bump, no LastStore reset).
if (match(Inst, m_Intrinsic<Intrinsic::assume>())) {
DEBUG(dbgs() << "EarlyCSE skipping assumption: " << *Inst << '\n');
continue;
}
// Instructions that fold to an existing value are replaced by it.
if (Value *V = SimplifyInstruction(Inst, DL, &TLI, &DT, &AC)) {
DEBUG(dbgs() << "EarlyCSE Simplify: " << *Inst << " to: " << *V << '\n');
Inst->replaceAllUsesWith(V);
Inst->eraseFromParent();
Changed = true;
++NumSimplify;
continue;
}
// Pure instructions: reuse an equivalent available one; otherwise make this
// one available. These entries carry no generation, because they do not
// depend on memory.
if (SimpleValue::canHandle(Inst)) {
if (Value *V = AvailableValues.lookup(Inst)) {
DEBUG(dbgs() << "EarlyCSE CSE: " << *Inst << " to: " << *V << '\n');
Inst->replaceAllUsesWith(V);
Inst->eraseFromParent();
Changed = true;
++NumCSE;
continue;
}
AvailableValues.insert(Inst, Inst);
continue;
}
// Classify memory behaviour: plain load/store, or a TTI-described intrinsic.
ParseMemoryInst MemInst(Inst, TTI);
if (MemInst.isValid() && MemInst.isLoad()) {
// Non-simple (volatile/atomic) loads are never CSE'd. They end the DSE
// window, and if the instruction reports it may write memory, they also
// start a new generation.
if (MemInst.isVolatile()) {
LastStore = nullptr;
if (Inst->mayWriteToMemory())
++CurrentGeneration;
continue;
}
// Reuse the value of an earlier load/store of the same pointer, but only if
// nothing that may write memory intervened (same generation).
std::pair<Value *, unsigned> InVal =
AvailableLoads.lookup(MemInst.getPtr());
if (InVal.first != nullptr && InVal.second == CurrentGeneration) {
// May return nullptr when no value of this load's type is obtainable.
Value *Op = getOrCreateResult(InVal.first, Inst->getType());
if (Op != nullptr) {
DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << *Inst
<< " to: " << *InVal.first << '\n');
if (!Inst->use_empty())
Inst->replaceAllUsesWith(Op);
Inst->eraseFromParent();
Changed = true;
++NumCSELoad;
continue;
}
}
// Not reusable: this load becomes the available value for its pointer.
AvailableLoads.insert(MemInst.getPtr(), std::pair<Value *, unsigned>(
Inst, CurrentGeneration));
// The load may have read the pending store, so that store is no longer dead.
LastStore = nullptr;
continue;
}
// Any other instruction that may read memory ends the DSE window, unless it
// is a TTI-described intrinsic that TTI says does not read memory.
if (Inst->mayReadFromMemory() &&
!(MemInst.isValid() && !MemInst.mayReadFromMemory()))
LastStore = nullptr;
// Read-only calls: reuse an identical call from the same generation;
// otherwise record this one.
if (CallValue::canHandle(Inst)) {
std::pair<Value *, unsigned> InVal = AvailableCalls.lookup(Inst);
if (InVal.first != nullptr && InVal.second == CurrentGeneration) {
DEBUG(dbgs() << "EarlyCSE CSE CALL: " << *Inst
<< " to: " << *InVal.first << '\n');
if (!Inst->use_empty())
Inst->replaceAllUsesWith(InVal.first);
Inst->eraseFromParent();
Changed = true;
++NumCSECall;
continue;
}
AvailableCalls.insert(
Inst, std::pair<Value *, unsigned>(Inst, CurrentGeneration));
continue;
}
// Anything that may write memory invalidates every load/call entry recorded
// so far, because it starts a new generation.
if (Inst->mayWriteToMemory()) {
++CurrentGeneration;
if (MemInst.isValid() && MemInst.isStore()) {
// Trivial DSE: an unread earlier store to the same location (same pointer
// and matching id) is dead.
if (LastStore) {
ParseMemoryInst LastStoreMemInst(LastStore, TTI);
if (LastStoreMemInst.isMatchingMemLoc(MemInst)) {
DEBUG(dbgs() << "EarlyCSE DEAD STORE: " << *LastStore
<< " due to: " << *Inst << '\n');
LastStore->eraseFromParent();
Changed = true;
++NumDSE;
LastStore = nullptr;
}
}
// Salvage something after the invalidation: record the store, tagged with
// the new generation, so later loads of this pointer can forward its value.
// Volatile stores are recorded too.
AvailableLoads.insert(MemInst.getPtr(), std::pair<Value *, unsigned>(
Inst, CurrentGeneration));
// Only non-volatile stores become DSE candidates.
if (!MemInst.isVolatile())
LastStore = Inst;
}
}
}
return Changed;
}
bool EarlyCSE::run() {
std::deque<StackNode *> nodesToProcess;
bool Changed = false;
nodesToProcess.push_back(new StackNode(
AvailableValues, AvailableLoads, AvailableCalls, CurrentGeneration,
DT.getRootNode(), DT.getRootNode()->begin(), DT.getRootNode()->end()));
unsigned LiveOutGeneration = CurrentGeneration;
while (!nodesToProcess.empty()) {
StackNode *NodeToProcess = nodesToProcess.back();
CurrentGeneration = NodeToProcess->currentGeneration();
if (!NodeToProcess->isProcessed()) {
Changed |= processNode(NodeToProcess->node());
NodeToProcess->childGeneration(CurrentGeneration);
NodeToProcess->process();
} else if (NodeToProcess->childIter() != NodeToProcess->end()) {
DomTreeNode *child = NodeToProcess->nextChild();
nodesToProcess.push_back(
new StackNode(AvailableValues, AvailableLoads, AvailableCalls,
NodeToProcess->childGeneration(), child, child->begin(),
child->end()));
} else {
delete NodeToProcess;
nodesToProcess.pop_back();
}
}
CurrentGeneration = LiveOutGeneration;
return Changed;
}
/// New-pass-manager entry point: fetch the required analyses and run
/// EarlyCSE on \p F. If nothing changed, every analysis is preserved;
/// otherwise only the dominator tree is.
PreservedAnalyses EarlyCSEPass::run(Function &F,
                                    AnalysisManager<Function> *AM) {
  auto &TLI = AM->getResult<TargetLibraryAnalysis>(F);
  auto &TTI = AM->getResult<TargetIRAnalysis>(F);
  auto &DT = AM->getResult<DominatorTreeAnalysis>(F);
  auto &AC = AM->getResult<AssumptionAnalysis>(F);

  EarlyCSE CSE(F, TLI, TTI, DT, AC);
  const bool Changed = CSE.run();
  if (Changed) {
    PreservedAnalyses PA;
    PA.preserve<DominatorTreeAnalysis>();
    return PA;
  }
  return PreservedAnalyses::all();
}
namespace {
class EarlyCSELegacyPass : public FunctionPass {
public:
static char ID;
EarlyCSELegacyPass() : FunctionPass(ID) {
initializeEarlyCSELegacyPassPass(*PassRegistry::getPassRegistry());
}
bool runOnFunction(Function &F) override {
if (skipOptnoneFunction(F))
return false;
auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
EarlyCSE CSE(F, TLI, TTI, DT, AC);
return CSE.run();
}
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<AssumptionCacheTracker>();
AU.addRequired<DominatorTreeWrapperPass>();
AU.addRequired<TargetLibraryInfoWrapperPass>();
AU.addRequired<TargetTransformInfoWrapperPass>();
AU.setPreservesCFG();
}
};
}
// Pass identifier for the legacy pass manager.
char EarlyCSELegacyPass::ID = 0;
// Public factory returning a new legacy EarlyCSE pass instance.
FunctionPass *llvm::createEarlyCSEPass() { return new EarlyCSELegacyPass(); }
// Register the legacy pass under the name "early-cse" and declare the four
// analyses it requires in getAnalysisUsage as initialization dependencies.
INITIALIZE_PASS_BEGIN(EarlyCSELegacyPass, "early-cse", "Early CSE", false,
false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(EarlyCSELegacyPass, "early-cse", "Early CSE", false, false)