#define DEBUG_TYPE "reassociate"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Function.h"
#include "llvm/Instructions.h"
#include "llvm/IntrinsicInst.h"
#include "llvm/Pass.h"
#include "llvm/Assembly/Writer.h"
#include "llvm/Support/CFG.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ValueHandle.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/Statistic.h"
#include <algorithm>
#include <map>
#include <set>
using namespace llvm;
// Counters for the transformations performed by this pass; each one is
// incremented at the site of the corresponding rewrite below.
STATISTIC(NumLinear , "Number of insts linearized");
STATISTIC(NumChanged, "Number of insts reassociated");
STATISTIC(NumAnnihil, "Number of expr tree annihilated");
STATISTIC(NumFactor , "Number of multiplies factored");
namespace {
  /// ValueEntry - One leaf operand of a linearized expression tree, together
  /// with the rank used to order it.
  struct VISIBILITY_HIDDEN ValueEntry {
    unsigned Rank;
    Value *Op;
    ValueEntry(unsigned Rk, Value *V) : Rank(Rk), Op(V) {}
  };

  /// Ordering used by std::stable_sort: an entry is "less" when its rank is
  /// *higher*, so sorting places the highest-ranked operands first.
  inline bool operator<(const ValueEntry &LHS, const ValueEntry &RHS) {
    return RHS.Rank < LHS.Rank;
  }
}
#ifndef NDEBUG
/// PrintOps - Debug helper: write I's opcode name and the type of the first
/// operand to cerr, followed by each operand (printed as an operand) and its
/// rank.
static void PrintOps(Instruction *I, const std::vector<ValueEntry> &Ops) {
  Module *Mod = I->getParent()->getParent()->getParent();
  cerr << Instruction::getOpcodeName(I->getOpcode()) << " "
       << *Ops[0].Op->getType();
  for (std::vector<ValueEntry>::const_iterator It = Ops.begin(),
       End = Ops.end(); It != End; ++It) {
    WriteAsOperand(*cerr.stream() << " ", It->Op, false, Mod);
    cerr << "," << It->Rank;
  }
}
#endif
namespace {
  /// Reassociate - Function pass that flattens, reorders and simplifies trees
  /// of associative integer binary operators.
  class VISIBILITY_HIDDEN Reassociate : public FunctionPass {
    // Per-block base rank (a counter value shifted left by 16), assigned in
    // reverse post-order by BuildRankMap.
    std::map<BasicBlock*, unsigned> RankMap;
    // Cached ranks of arguments, unmovable instructions and any instruction
    // whose rank getRank has computed.  Keys are AssertingVH handles, so an
    // entry must be erased before its instruction is deleted.
    std::map<AssertingVH<>, unsigned> ValueRankMap;
    // Set whenever the pass modifies the IR.
    bool MadeChange;

  public:
    static char ID; // Pass identification.
    Reassociate() : FunctionPass(&ID) {}

    bool runOnFunction(Function &F);

    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
      AU.setPreservesCFG();
    }

  private:
    void BuildRankMap(Function &Fn);
    unsigned getRank(Value *Val);
    void ReassociateExpression(BinaryOperator *Root);
    void RewriteExprTree(BinaryOperator *Node,
                         std::vector<ValueEntry> &Operands,
                         unsigned Idx = 0);
    Value *OptimizeExpression(BinaryOperator *Root,
                              std::vector<ValueEntry> &Operands);
    void LinearizeExprTree(BinaryOperator *Node,
                           std::vector<ValueEntry> &Operands);
    void LinearizeExpr(BinaryOperator *Node);
    Value *RemoveFactorFromExpression(Value *Expr, Value *Factor);
    void ReassociateBB(BasicBlock *Block);
    void RemoveDeadBinaryOp(Value *Val);
  };
}
char Reassociate::ID = 0;

// Register the pass under the name "reassociate".
static RegisterPass<Reassociate> X("reassociate", "Reassociate expressions");

/// createReassociatePass - Public factory returning a new instance of this
/// pass.
FunctionPass *llvm::createReassociatePass() {
  return new Reassociate();
}
/// RemoveDeadBinaryOp - If V is a binary operator with no remaining uses,
/// erase it, then apply the same cleanup to its two operands.  Callers use
/// this to delete expression-tree nodes that became dead after the tree was
/// rewritten with fewer operands, or after the whole tree was replaced.
///
/// Anything that is not a use-free BinaryOperator (constants, arguments,
/// other instructions, still-used values) is left untouched.
void Reassociate::RemoveDeadBinaryOp(Value *V) {
  Instruction *Op = dyn_cast<Instruction>(V);
  if (!Op || !isa<BinaryOperator>(Op) || !Op->use_empty())
    return;

  // Grab the operands before Op is destroyed.  RHS is tracked with a WeakVH:
  // cleaning up LHS below can delete RHS too (e.g. x*x, or an RHS that was
  // only used inside LHS), and the handle becomes null when that happens.
  Value *LHS = Op->getOperand(0);
  WeakVH RHS(Op->getOperand(1));

  // ValueRankMap is keyed by AssertingVH, so Op's entry must be dropped
  // before Op is deleted.
  ValueRankMap.erase(Op);
  Op->eraseFromParent();

  // Erasing Op may have left its operands without uses; clean them up too.
  RemoveDeadBinaryOp(LHS);
  if (RHS)
    RemoveDeadBinaryOp(RHS);
}
/// isUnmovableInstruction - Return true for instructions this pass treats as
/// pinned: PHIs, allocas, loads, mallocs, invokes, calls other than debug-info
/// intrinsics, and all integer/floating-point divisions and remainders.
/// BuildRankMap gives each such instruction its own precomputed rank.
static bool isUnmovableInstruction(Instruction *I) {
  switch (I->getOpcode()) {
  case Instruction::PHI:
  case Instruction::Alloca:
  case Instruction::Load:
  case Instruction::Malloc:
  case Instruction::Invoke:
  case Instruction::UDiv:
  case Instruction::SDiv:
  case Instruction::FDiv:
  case Instruction::URem:
  case Instruction::SRem:
  case Instruction::FRem:
    return true;
  case Instruction::Call:
    // Calls are pinned unless they are debug-info intrinsics.
    return !isa<DbgInfoIntrinsic>(I);
  default:
    return false;
  }
}
/// BuildRankMap - Seed the rank tables for F.
///
/// - Arguments get consecutive ranks starting at 3.
/// - Each basic block, visited in reverse post-order, gets the next counter
///   value shifted left by 16 as its base rank.
/// - Every unmovable instruction in that block gets a distinct rank just
///   above the block's base.
void Reassociate::BuildRankMap(Function &F) {
  unsigned Counter = 2;

  for (Function::arg_iterator AI = F.arg_begin(), AE = F.arg_end();
       AI != AE; ++AI) {
    ++Counter;
    ValueRankMap[&*AI] = Counter;
  }

  ReversePostOrderTraversal<Function*> RPOT(&F);
  for (ReversePostOrderTraversal<Function*>::rpo_iterator BI = RPOT.begin(),
       BE = RPOT.end(); BI != BE; ++BI) {
    BasicBlock *BB = *BI;
    ++Counter;
    unsigned NextRank = Counter << 16;
    RankMap[BB] = NextRank;

    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE; ++II)
      if (isUnmovableInstruction(II))
        ValueRankMap[&*II] = ++NextRank;
  }
}
/// getRank - Return the rank of V, computing and caching it on first request.
///
/// - Arguments return their precomputed rank.
/// - Non-instruction values (constants, globals, ...) have rank 0.
/// - An instruction's rank is one more than the highest operand rank. The
///   operand scan stops early once the rank of the instruction's own block
///   is reached.
/// - Exception: integer not/neg instructions keep their operand's rank, so
///   X, ~X and -X rank equally.
unsigned Reassociate::getRank(Value *V) {
  if (isa<Argument>(V))
    return ValueRankMap[V];

  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return 0;

  // std::map references stay valid across the insertions performed by the
  // recursive getRank calls below.
  unsigned &Cached = ValueRankMap[I];
  if (Cached != 0)
    return Cached;

  const unsigned BlockRank = RankMap[I->getParent()];
  unsigned Best = 0;
  for (unsigned OpNo = 0, NumOps = I->getNumOperands();
       OpNo != NumOps && Best != BlockRank; ++OpNo) {
    unsigned OpRank = getRank(I->getOperand(OpNo));
    if (OpRank > Best)
      Best = OpRank;
  }

  bool IsIntNotOrNeg = I->getType()->isInteger() &&
                       (BinaryOperator::isNot(I) || BinaryOperator::isNeg(I));
  if (!IsIntNotOrNeg)
    ++Best;

  Cached = Best;
  return Best;
}
/// isReassociableOp - If V is an instruction with the given opcode and at most
/// one use, return it as a BinaryOperator; otherwise return null.
static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode) {
  Instruction *Inst = dyn_cast<Instruction>(V);
  if (Inst == 0 || Inst->getOpcode() != Opcode)
    return 0;
  if (!V->use_empty() && !V->hasOneUse())
    return 0;
  return cast<BinaryOperator>(Inst);
}
/// LowerNegateToMultiply - Replace the negation Neg with "operand(1) * -1".
/// Callers check BinaryOperator::isNeg(Neg) first.
///
/// The new multiply is inserted before Neg, takes over Neg's name and uses,
/// and Neg is erased after its ValueRankMap entry is dropped.  Returns the
/// new multiply.
static Instruction *LowerNegateToMultiply(Instruction *Neg,
                              std::map<AssertingVH<>, unsigned> &ValueRankMap) {
  Value *Negated = Neg->getOperand(1);
  Constant *MinusOne = ConstantInt::getAllOnesValue(Neg->getType());
  Instruction *Mul = BinaryOperator::CreateMul(Negated, MinusOne, "", Neg);

  ValueRankMap.erase(Neg);
  Mul->takeName(Neg);
  Neg->replaceAllUsesWith(Mul);
  Neg->eraseFromParent();
  return Mul;
}
/// LinearizeExpr - Rotate I, whose two operands are both reassociable nodes
/// with I's opcode, until its RHS is no longer such a node.
///
/// Each step rewrites  L op (R0 op R1)  into  (L op R1) op R0, reusing the
/// old RHS node as the new LHS.  The step repeats while the new RHS (R0) is
/// itself a reassociable node.
void Reassociate::LinearizeExpr(BinaryOperator *I) {
  const unsigned Opc = I->getOpcode();
  do {
    BinaryOperator *Left = cast<BinaryOperator>(I->getOperand(0));
    BinaryOperator *Right = cast<BinaryOperator>(I->getOperand(1));
    assert(isReassociableOp(Left, Opc) &&
           isReassociableOp(Right, Opc) &&
           "Not an expression that needs linearization?");

    DOUT << "Linear" << *Left << *Right << *I;

    // Put Right immediately before I: once it uses Left, all of its operands
    // must still come before it.
    Right->moveBefore(I);

    // L op (R0 op R1)  ==>  (L op R1) op R0
    I->setOperand(1, Right->getOperand(0));
    Right->setOperand(0, Left);
    I->setOperand(0, Right);

    ++NumLinear;
    MadeChange = true;
    DOUT << "Linearized: " << *I;
  } while (isReassociableOp(I->getOperand(1), Opc));
}
/// LinearizeExprTree - Flatten the expression tree rooted at I into Ops.
///
/// The tree consists of nodes that share I's opcode and have at most one use.
/// Each leaf is appended as a (rank, value) entry, left subtree first.
/// Along the way the tree is reshaped into a left-leaning chain whose nodes
/// are moved just before their user.
///
/// Each leaf slot is overwritten with undef once it has been recorded.  The
/// tree afterwards holds only its interior structure; callers are expected to
/// write leaves back with RewriteExprTree, or to discard the tree.
void Reassociate::LinearizeExprTree(BinaryOperator *I,
std::vector<ValueEntry> &Ops) {
Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
unsigned Opcode = I->getOpcode();
BinaryOperator *LHSBO = isReassociableOp(LHS, Opcode);
BinaryOperator *RHSBO = isReassociableOp(RHS, Opcode);
// In a multiply tree, a single-use negation operand is rewritten as X * -1 so
// that it becomes part of the tree.
if (I->getOpcode() == Instruction::Mul) {
if (!LHSBO && LHS->hasOneUse() && BinaryOperator::isNeg(LHS)) {
LHS = LowerNegateToMultiply(cast<Instruction>(LHS), ValueRankMap);
LHSBO = isReassociableOp(LHS, Opcode);
}
if (!RHSBO && RHS->hasOneUse() && BinaryOperator::isNeg(RHS)) {
RHS = LowerNegateToMultiply(cast<Instruction>(RHS), ValueRankMap);
RHSBO = isReassociableOp(RHS, Opcode);
}
}
if (!LHSBO) {
if (!RHSBO) {
// Neither operand is a tree node: I is a leaf pair.  Record both operands
// and clear them out.
Ops.push_back(ValueEntry(getRank(LHS), LHS));
Ops.push_back(ValueEntry(getRank(RHS), RHS));
I->setOperand(0, UndefValue::get(I->getType()));
I->setOperand(1, UndefValue::get(I->getType()));
return;
} else {
// X op (tree)  ->  (tree) op X, so the tree is always on the LHS.
std::swap(LHSBO, RHSBO);
std::swap(LHS, RHS);
// swapOperands() returns false on success.
bool Success = !I->swapOperands();
assert(Success && "swapOperands failed");
// Presumably here only to keep Success "used" when asserts are compiled out.
Success = false;
MadeChange = true;
}
} else if (RHSBO) {
// (tree) op (tree): rotate so that the RHS is no longer a tree node.
LinearizeExpr(I);
LHS = LHSBO = cast<BinaryOperator>(I->getOperand(0));
RHS = I->getOperand(1);
RHSBO = 0;
}
// From here on the LHS is a tree node and the RHS is a leaf.
assert(!isReassociableOp(RHS, Opcode) && "LinearizeExpr failed!");
// Keep the subtree node immediately before its user.
LHSBO->moveBefore(I);
LinearizeExprTree(LHSBO, Ops);
Ops.push_back(ValueEntry(getRank(RHS), RHS));
I->setOperand(1, UndefValue::get(I->getType()));
}
/// RewriteExprTree - Write Ops[i..] back into the left-leaning tree rooted at
/// I.
///
/// Walking down the left spine, each interior node takes Ops[k] as its RHS.
/// The bottom node receives the last two operands.  Interior nodes are moved
/// to sit directly before their user.
///
/// If the bottom node's old LHS is replaced, it is handed to
/// RemoveDeadBinaryOp.  This matters when the tree now has fewer operands
/// than before, leaving nodes below unused.
void Reassociate::RewriteExprTree(BinaryOperator *I,
                                  std::vector<ValueEntry> &Ops,
                                  unsigned i) {
  BinaryOperator *Node = I;
  unsigned Idx = i;

  while (Idx + 2 != Ops.size()) {
    assert(Idx + 2 < Ops.size() && "Ops index out of range!");

    if (Node->getOperand(1) != Ops[Idx].Op) {
      DOUT << "RA: " << *Node;
      Node->setOperand(1, Ops[Idx].Op);
      DOUT << "TO: " << *Node;
      MadeChange = true;
      ++NumChanged;
    }

    BinaryOperator *Child = cast<BinaryOperator>(Node->getOperand(0));
    assert(Child->getOpcode() == Node->getOpcode() &&
           "Improper expression tree!");

    // Keep the tree compact so every operand dominates it.
    Child->moveBefore(Node);
    Node = Child;
    ++Idx;
  }

  // Bottom node: nothing to do if it already holds the final two operands.
  if (Node->getOperand(0) == Ops[Idx].Op &&
      Node->getOperand(1) == Ops[Idx+1].Op)
    return;

  Value *OldLHS = Node->getOperand(0);
  DOUT << "RA: " << *Node;
  Node->setOperand(0, Ops[Idx].Op);
  Node->setOperand(1, Ops[Idx+1].Op);
  DOUT << "TO: " << *Node;
  MadeChange = true;
  ++NumChanged;
  RemoveDeadBinaryOp(OldLHS);
}
/// NegateValue - Return a value equal to -V, inserting new instructions
/// before BI.
///
/// If V is a single-use add, the negation is pushed into both of its operands
/// (recursively), and the add itself is moved before BI so the new negations
/// come first.  The add is renamed with a ".neg" suffix.
///
/// Otherwise a fresh negation named "<V>.neg" is created.
static Value *NegateValue(Value *V, Instruction *BI) {
  Instruction *Add = dyn_cast<Instruction>(V);
  if (!Add || Add->getOpcode() != Instruction::Add || !Add->hasOneUse())
    return BinaryOperator::CreateNeg(V, V->getName() + ".neg", BI);

  // -(A + B) == (-A) + (-B)
  Value *NegLHS = NegateValue(Add->getOperand(0), BI);
  Add->setOperand(0, NegLHS);
  Value *NegRHS = NegateValue(Add->getOperand(1), BI);
  Add->setOperand(1, NegRHS);

  Add->moveBefore(BI);
  Add->setName(Add->getName() + ".neg");
  return Add;
}
/// ShouldBreakUpSubtract - Decide whether Sub should be rewritten as an add of
/// a negation.
///
/// Negations themselves are never split.  Otherwise split when:
/// - either operand is a reassociable add/sub, or
/// - Sub has exactly one use and that user is a reassociable add/sub.
static bool ShouldBreakUpSubtract(Instruction *Sub) {
  if (BinaryOperator::isNeg(Sub))
    return false;

  Value *LHS = Sub->getOperand(0);
  Value *RHS = Sub->getOperand(1);
  return isReassociableOp(LHS, Instruction::Add) ||
         isReassociableOp(LHS, Instruction::Sub) ||
         isReassociableOp(RHS, Instruction::Add) ||
         isReassociableOp(RHS, Instruction::Sub) ||
         (Sub->hasOneUse() &&
          (isReassociableOp(Sub->use_back(), Instruction::Add) ||
           isReassociableOp(Sub->use_back(), Instruction::Sub)));
}
/// BreakUpSubtract - Rewrite "A - B" as "A + (-B)".
///
/// New instructions are inserted before Sub.  The resulting add takes over
/// Sub's name and uses, and Sub is erased after its ValueRankMap entry is
/// dropped.  Returns the add.
static Instruction *BreakUpSubtract(Instruction *Sub,
                              std::map<AssertingVH<>, unsigned> &ValueRankMap) {
  Value *NegRHS = NegateValue(Sub->getOperand(1), Sub);
  Instruction *Sum =
    BinaryOperator::CreateAdd(Sub->getOperand(0), NegRHS, "", Sub);
  Sum->takeName(Sub);

  // Redirect every user of Sub to the add, then drop Sub.
  ValueRankMap.erase(Sub);
  Sub->replaceAllUsesWith(Sum);
  Sub->eraseFromParent();

  DOUT << "Negated: " << *Sum;
  return Sum;
}
/// ConvertShiftToMul - Rewrite "shl X, C" as "mul X, (1 << C)".
///
/// The rewrite only happens when:
/// - X is a reassociable multiply, or
/// - the shift has a single user that is a reassociable multiply or add.
///
/// Returns the new multiply, or null when no rewrite happened.  The shift
/// amount is assumed to be a constant (the caller checks for ConstantInt).
static Instruction *ConvertShiftToMul(Instruction *Shl,
                              std::map<AssertingVH<>, unsigned> &ValueRankMap) {
  if (!isReassociableOp(Shl->getOperand(0), Instruction::Mul) &&
      !(Shl->hasOneUse() &&
        (isReassociableOp(Shl->use_back(), Instruction::Mul) ||
         isReassociableOp(Shl->use_back(), Instruction::Add))))
    return 0;

  Constant *One = ConstantInt::get(Shl->getType(), 1);
  Constant *Scale =
    ConstantExpr::getShl(One, cast<Constant>(Shl->getOperand(1)));

  Instruction *Mul =
    BinaryOperator::CreateMul(Shl->getOperand(0), Scale, "", Shl);
  ValueRankMap.erase(Shl);
  Mul->takeName(Shl);
  Shl->replaceAllUsesWith(Mul);
  Shl->eraseFromParent();
  return Mul;
}
/// FindInOperandList - Search for X among the entries adjacent to Ops[i] that
/// share its rank.  Ops is expected to be sorted by rank, so those entries
/// form one contiguous run.
///
/// The run is scanned forward first, then backward.  Returns X's index, or i
/// when X is not found.
static unsigned FindInOperandList(std::vector<ValueEntry> &Ops, unsigned i,
                                  Value *X) {
  const unsigned TargetRank = Ops[i].Rank;
  const unsigned Count = Ops.size();

  for (unsigned Fwd = i + 1; Fwd < Count; ++Fwd) {
    if (Ops[Fwd].Rank != TargetRank)
      break;
    if (Ops[Fwd].Op == X)
      return Fwd;
  }

  for (unsigned Back = i; Back > 0; --Back) {
    if (Ops[Back - 1].Rank != TargetRank)
      break;
    if (Ops[Back - 1].Op == X)
      return Back - 1;
  }

  return i;
}
/// EmitAddTreeOfValues - Sum the values in Ops with a left-leaning chain of
/// "tmp" adds inserted before I:  ((Ops[0] + Ops[1]) + Ops[2]) + ...
///
/// Returns the final sum, or Ops[0] itself when it is the only entry.  Ops
/// must be non-empty; on return it has been truncated to its first element.
static Value *EmitAddTreeOfValues(Instruction *I, std::vector<Value*> &Ops) {
  Value *Sum = Ops.front();
  for (std::vector<Value*>::size_type K = 1; K < Ops.size(); ++K)
    Sum = BinaryOperator::CreateAdd(Sum, Ops[K], "tmp", I);
  Ops.erase(Ops.begin() + 1, Ops.end());
  return Sum;
}
/// RemoveFactorFromExpression - If V is a reassociable multiply tree that
/// contains Factor, remove one occurrence of Factor and return what remains.
///
/// The remainder is either the single remaining factor, or the rewritten tree
/// root.  Returns null when V is not such a tree, or when Factor does not
/// occur in it; in the latter case the linearized operands are written back
/// into the tree.
Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) {
  BinaryOperator *BO = isReassociableOp(V, Instruction::Mul);
  if (!BO)
    return 0;

  std::vector<ValueEntry> Factors;
  LinearizeExprTree(BO, Factors);

  std::vector<ValueEntry>::iterator Hit = Factors.begin();
  while (Hit != Factors.end() && Hit->Op != Factor)
    ++Hit;

  if (Hit == Factors.end()) {
    // Factor not present: restore the leaves into the tree.
    RewriteExprTree(BO, Factors);
    return 0;
  }
  Factors.erase(Hit);

  if (Factors.size() == 1)
    return Factors[0].Op;

  RewriteExprTree(BO, Factors);
  return BO;
}
/// FindSingleUseMultiplyFactors - Collect into Factors the leaves of the
/// multiply tree rooted at V.
///
/// A value is expanded only if it is a Mul with at most one use; anything
/// else is recorded as a leaf.  Operand 1 is visited before operand 0.
static void FindSingleUseMultiplyFactors(Value *V,
                                         std::vector<Value*> &Factors) {
  BinaryOperator *Mul = dyn_cast<BinaryOperator>(V);
  bool Expand = (V->hasOneUse() || V->use_empty()) &&
                Mul && Mul->getOpcode() == Instruction::Mul;
  if (!Expand) {
    Factors.push_back(V);
    return;
  }

  FindSingleUseMultiplyFactors(Mul->getOperand(1), Factors);
  FindSingleUseMultiplyFactors(Mul->getOperand(0), Factors);
}
/// OptimizeExpression - Simplify Ops, the rank-sorted operand list of the
/// associative expression rooted at I.
///
/// Returns a value that replaces the whole expression when it collapses to a
/// single value or constant.  Otherwise returns null, and the caller writes
/// the (possibly modified) Ops back into the tree.
///
/// Simplifications performed:
/// - folding of the two trailing constants;
/// - identity/annihilating constants (x&0, x&-1, x*0, x*1, x|-1, x|0, x^0,
///   x+0);
/// - X&~X and X|~X;
/// - adjacent duplicate operands of and/or/xor;
/// - X + -X;
/// - for adds, factoring a common multiplicand out of the multiplied
///   operands.
Value *Reassociate::OptimizeExpression(BinaryOperator *I,
std::vector<ValueEntry> &Ops) {
// Set when Ops was changed in a way that may enable further simplification;
// triggers another round at the end.
bool IterateOptimization = false;
if (Ops.size() == 1) return Ops[0].Op;
unsigned Opcode = I->getOpcode();
// Ops is sorted by descending rank, and rank-0 values (constants, globals)
// sort to the end.  Fold the last two entries if both are constants.
if (Constant *V1 = dyn_cast<Constant>(Ops[Ops.size()-2].Op))
if (Constant *V2 = dyn_cast<Constant>(Ops.back().Op)) {
Ops.pop_back();
Ops.back().Op = ConstantExpr::get(Opcode, V1, V2);
return OptimizeExpression(I, Ops);
}
// Identity / annihilator checks against a trailing integer constant.
if (ConstantInt *CstVal = dyn_cast<ConstantInt>(Ops.back().Op))
switch (Opcode) {
default: break;
case Instruction::And:
// X & 0 -> 0;  X & -1 -> X
if (CstVal->isZero()) { ++NumAnnihil;
return CstVal;
} else if (CstVal->isAllOnesValue()) { Ops.pop_back();
}
break;
case Instruction::Mul:
// X * 0 -> 0;  X * 1 -> X
if (CstVal->isZero()) { ++NumAnnihil;
return CstVal;
} else if (cast<ConstantInt>(CstVal)->isOne()) {
Ops.pop_back(); }
break;
case Instruction::Or:
// X | -1 -> -1
if (CstVal->isAllOnesValue()) { ++NumAnnihil;
return CstVal;
}
// Intentional fall-through: X | 0 -> X is handled with add/xor below.
case Instruction::Add:
case Instruction::Xor:
// X + 0 -> X;  X ^ 0 -> X  (and X | 0 -> X via the fall-through)
if (CstVal->isZero()) Ops.pop_back();
break;
}
if (Ops.size() == 1) return Ops[0].Op;
switch (Opcode) {
default: break;
case Instruction::And:
case Instruction::Or:
case Instruction::Xor:
// Look for X / ~X pairs and for adjacent duplicate operands.
for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
assert(i < Ops.size());
// X & ~X -> 0,  X | ~X -> -1.  Nothing is done here for xor.
if (BinaryOperator::isNot(Ops[i].Op)) { Value *X = BinaryOperator::getNotArgument(Ops[i].Op);
unsigned FoundX = FindInOperandList(Ops, i, X);
if (FoundX != i) {
if (Opcode == Instruction::And) { ++NumAnnihil;
return Constant::getNullValue(X->getType());
} else if (Opcode == Instruction::Or) { ++NumAnnihil;
return ConstantInt::getAllOnesValue(X->getType());
}
}
}
assert(i < Ops.size());
// Only *adjacent* duplicates are detected.  Equal values share a rank, so
// the rank sort tends to group them; interleaved equal-rank values can still
// keep duplicates apart.
if (i+1 != Ops.size() && Ops[i+1].Op == Ops[i].Op) {
if (Opcode == Instruction::And || Opcode == Instruction::Or) {
// X & X -> X,  X | X -> X: drop one copy and revisit position i.
Ops.erase(Ops.begin()+i);
--i; --e;
IterateOptimization = true;
++NumAnnihil;
} else {
// X ^ X -> 0: drop both copies.  The whole expression is 0 if they were
// the only two operands.
assert(Opcode == Instruction::Xor);
if (e == 2) {
++NumAnnihil;
return Constant::getNullValue(Ops[0].Op->getType());
}
Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
// Combined with the loop's ++i this revisits position i.  The unsigned
// wrap-around when i == 0 is harmless.
i -= 1; e -= 2;
IterateOptimization = true;
++NumAnnihil;
}
}
}
break;
case Instruction::Add:
// X + -X -> 0: remove both operands, or return 0 if they were the only two.
for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
assert(i < Ops.size());
if (BinaryOperator::isNeg(Ops[i].Op)) {
Value *X = BinaryOperator::getNegArgument(Ops[i].Op);
unsigned FoundX = FindInOperandList(Ops, i, X);
if (FoundX != i) {
if (Ops.size() == 2) {
++NumAnnihil;
return Constant::getNullValue(X->getType());
} else {
Ops.erase(Ops.begin()+i);
// Adjust indices for the erase above.  Only `--i` is governed by the
// `else`; the erase of FoundX after it always runs.
if (i < FoundX)
--FoundX;
else
--i; Ops.erase(Ops.begin()+FoundX);
IterateOptimization = true;
++NumAnnihil;
// Two elements were removed; back up so the loop's ++i revisits position i.
--i; e -= 2; }
}
}
}
// Factoring, e.g. A*B + A*C + D -> A*(B+C) + D.
// For each use-free multiply operand (presumably its only use was this
// tree's now-cleared leaf slot), count each distinct factor once.  Track the
// factor that appears in the most such multiplies.
std::map<Value*, unsigned> FactorOccurrences;
unsigned MaxOcc = 0;
Value *MaxOccVal = 0;
for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
if (BinaryOperator *BOp = dyn_cast<BinaryOperator>(Ops[i].Op)) {
if (BOp->getOpcode() == Instruction::Mul && BOp->use_empty()) {
std::vector<Value*> Factors;
FindSingleUseMultiplyFactors(BOp, Factors);
assert(Factors.size() > 1 && "Bad linearize!");
if (Factors.size() == 2) {
unsigned Occ = ++FactorOccurrences[Factors[0]];
if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[0]; }
// Don't count A*A twice.
if (Factors[0] != Factors[1]) { Occ = ++FactorOccurrences[Factors[1]];
if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[1]; }
}
} else {
// Count each distinct factor of this multiply once.
std::set<Value*> Duplicates;
for (unsigned i = 0, e = Factors.size(); i != e; ++i) {
if (Duplicates.insert(Factors[i]).second) {
unsigned Occ = ++FactorOccurrences[Factors[i]];
if (Occ > MaxOcc) { MaxOcc = Occ; MaxOccVal = Factors[i]; }
}
}
}
}
}
}
// A factor shared by at least two multiplies is pulled out of every operand
// that contains it.
if (MaxOcc > 1) {
DOUT << "\nFACTORING [" << MaxOcc << "]: " << *MaxOccVal << "\n";
// Never inserted; gives MaxOccVal two extra uses for the duration of the
// loop.  Presumably this keeps use-count checks (hasOneUse in
// isReassociableOp) from changing between iterations as the factor is
// removed from earlier operands.
Instruction *DummyInst = BinaryOperator::CreateAdd(MaxOccVal, MaxOccVal);
std::vector<Value*> NewMulOps;
for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
if (Value *V = RemoveFactorFromExpression(Ops[i].Op, MaxOccVal)) {
NewMulOps.push_back(V);
Ops.erase(Ops.begin()+i);
--i; --e;
}
}
delete DummyInst;
// Emit (sum of the reduced operands) * MaxOccVal before I.
unsigned NumAddedValues = NewMulOps.size();
Value *V = EmitAddTreeOfValues(I, NewMulOps);
Value *V2 = BinaryOperator::CreateMul(V, MaxOccVal, "tmp", I);
// The new sum may itself be factorable (e.g. A*A*B + A*A*C), so reassociate
// it too.
if (NumAddedValues > 1)
ReassociateExpression(cast<BinaryOperator>(V));
++NumFactor;
if (Ops.empty())
return V2;
// Put the product at the front (not re-sorted), write the tree back, and try
// again.
Ops.insert(Ops.begin(), ValueEntry(getRank(V2), V2));
RewriteExprTree(I, Ops);
return OptimizeExpression(I, Ops);
}
break;
}
if (IterateOptimization)
return OptimizeExpression(I, Ops);
return 0;
}
/// ReassociateBB - Visit the instructions of BB in order and reassociate each
/// expression tree at its root.
///
/// Before that, some instructions are canonicalized so they can join larger
/// trees:
/// - shifts by a constant become multiplies;
/// - suitable subtracts become adds of negations;
/// - suitable negations become multiplies by -1.
///
/// Floating-point and vector operations are left untouched.
void Reassociate::ReassociateBB(BasicBlock *BB) {
  BasicBlock::iterator It = BB->begin();
  while (It != BB->end()) {
    // Step past Cur first: the helpers below may replace and erase it.
    Instruction *Cur = It;
    ++It;

    if (Cur->getOpcode() == Instruction::Shl &&
        isa<ConstantInt>(Cur->getOperand(1))) {
      Instruction *AsMul = ConvertShiftToMul(Cur, ValueRankMap);
      if (AsMul) {
        MadeChange = true;
        Cur = AsMul;
      }
    }

    if (!isa<BinaryOperator>(Cur) || Cur->getType()->isFloatingPoint() ||
        isa<VectorType>(Cur->getType()))
      continue;

    if (Cur->getOpcode() == Instruction::Sub) {
      if (ShouldBreakUpSubtract(Cur)) {
        Cur = BreakUpSubtract(Cur, ValueRankMap);
        MadeChange = true;
      } else if (BinaryOperator::isNeg(Cur) &&
                 isReassociableOp(Cur->getOperand(1), Instruction::Mul) &&
                 (!Cur->hasOneUse() ||
                  !isReassociableOp(Cur->use_back(), Instruction::Mul))) {
        // A negated multiply tree that is not itself inside a multiply tree:
        // lower it to "* -1" so it becomes part of the multiply tree.
        Cur = LowerNegateToMultiply(Cur, ValueRankMap);
        MadeChange = true;
      }
    }

    if (!Cur->isAssociative())
      continue;
    BinaryOperator *Root = cast<BinaryOperator>(Cur);

    if (Root->hasOneUse()) {
      // Interior node of a larger same-opcode tree: wait for the root.
      if (isReassociableOp(Root->use_back(), Root->getOpcode()))
        continue;
      // An add feeding a subtract is skipped for now.  The subtract comes
      // later in the block and may itself be broken up into an add.
      if (Root->getOpcode() == Instruction::Add &&
          cast<Instruction>(Root->use_back())->getOpcode() == Instruction::Sub)
        continue;
    }

    ReassociateExpression(Root);
  }
}
/// ReassociateExpression - Reassociate the expression tree rooted at I.
///
/// Steps:
/// 1. Flatten the tree into a leaf list and sort it by descending rank.
/// 2. Simplify the list with OptimizeExpression.
/// 3. Either replace I outright, or write the list back into the tree.
void Reassociate::ReassociateExpression(BinaryOperator *I) {
  std::vector<ValueEntry> Ops;
  LinearizeExprTree(I, Ops);

  DOUT << "RAIn:\t"; DEBUG(PrintOps(I, Ops)); DOUT << "\n";

  // Highest rank first.  Stable, so equal-rank operands keep their relative
  // order and the output is deterministic.
  std::stable_sort(Ops.begin(), Ops.end());

  Value *Simplified = OptimizeExpression(I, Ops);
  if (Simplified) {
    DOUT << "Reassoc to scalar: " << *Simplified << "\n";
    I->replaceAllUsesWith(Simplified);
    RemoveDeadBinaryOp(I);
    return;
  }

  // Special case: a multiply whose only user is an add, and whose last
  // operand is the constant -1.  The -1 is moved to the front of the list
  // instead of staying at the end; presumably this lets the negation later
  // fold into the add.
  bool MulFeedingAdd = I->getOpcode() == Instruction::Mul && I->hasOneUse() &&
      cast<Instruction>(I->use_back())->getOpcode() == Instruction::Add;
  if (MulFeedingAdd) {
    ConstantInt *Last = dyn_cast<ConstantInt>(Ops.back().Op);
    if (Last && Last->isAllOnesValue())
      std::rotate(Ops.begin(), Ops.end() - 1, Ops.end());
  }

  DOUT << "RAOut:\t"; DEBUG(PrintOps(I, Ops)); DOUT << "\n";

  if (Ops.size() != 1) {
    RewriteExprTree(I, Ops);
    return;
  }

  // The tree reduced to a single value: replace I with it.
  I->replaceAllUsesWith(Ops[0].Op);
  RemoveDeadBinaryOp(I);
}
/// runOnFunction - Build the rank tables for F, reassociate every basic block,
/// then discard the tables.  Returns true if F was modified.
bool Reassociate::runOnFunction(Function &F) {
  MadeChange = false;
  BuildRankMap(F);

  for (Function::iterator BB = F.begin(), BE = F.end(); BB != BE; ++BB)
    ReassociateBB(BB);

  // The rank tables are specific to F; release them before the next function.
  ValueRankMap.clear();
  RankMap.clear();
  return MadeChange;
}