#include "llvm/Transforms/Scalar.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
using namespace llvm;
// Tag under which this file's DEBUG() output and STATISTIC counters appear.
#define DEBUG_TYPE "reassociate"
// LLVM statistic counters: instructions rewritten by RewriteExprTree, operand
// pairs cancelled by OptimizeAndOrXor/OptimizeAdd, and factorizations done by
// OptimizeAdd.
STATISTIC(NumChanged, "Number of insts reassociated");
STATISTIC(NumAnnihil, "Number of expr tree annihilated");
STATISTIC(NumFactor , "Number of multiplies factored");
namespace {
/// One operand of a linearized expression, tagged with its rank.
struct ValueEntry {
  unsigned Rank;
  Value *Op;
  ValueEntry(unsigned RankIn, Value *OpIn) : Rank(RankIn), Op(OpIn) {}
};

/// "Less than" means "higher rank": sorting a list of entries puts the
/// highest-ranked operands first.
inline bool operator<(const ValueEntry &LHS, const ValueEntry &RHS) {
  return RHS.Rank < LHS.Rank;
}
} // end anonymous namespace
#ifndef NDEBUG
/// Debug helper: print I's opcode name, the type of the first operand in Ops,
/// and then each operand of Ops as "[ <operand>, #<rank>] " to dbgs().
/// Ops must not be empty.
static void PrintOps(Instruction *I, const SmallVectorImpl<ValueEntry> &Ops) {
  Module *Mod = I->getParent()->getParent()->getParent();
  raw_ostream &OS = dbgs();
  OS << Instruction::getOpcodeName(I->getOpcode()) << " "
     << *Ops[0].Op->getType() << '\t';
  for (const ValueEntry &Entry : Ops) {
    OS << "[ ";
    Entry.Op->printAsOperand(OS, false, Mod);
    OS << ", #" << Entry.Rank << "] ";
  }
}
#endif
namespace {
/// A value that occurs Power times as a multiplicand of a product.
struct Factor {
  Value *Base;
  unsigned Power;

  Factor(Value *Base, unsigned Power) : Base(Base), Power(Power) {}

  // The comparison functors are const-callable, as the standard Compare /
  // BinaryPredicate requirements expect.

  /// Orders factors by the address of their base value.
  struct BaseSorter {
    bool operator()(const Factor &LHS, const Factor &RHS) const {
      return LHS.Base < RHS.Base;
    }
  };

  /// True when two factors have the same base value.
  struct BaseEqual {
    bool operator()(const Factor &LHS, const Factor &RHS) const {
      return LHS.Base == RHS.Base;
    }
  };

  /// Orders factors by decreasing power.
  struct PowerDescendingSorter {
    bool operator()(const Factor &LHS, const Factor &RHS) const {
      return LHS.Power > RHS.Power;
    }
  };

  /// True when two factors have the same power.
  struct PowerEqual {
    bool operator()(const Factor &LHS, const Factor &RHS) const {
      return LHS.Power == RHS.Power;
    }
  };
};

/// One operand of an xor expression, viewed as "SymbolicPart | ConstPart"
/// when isOrExpr() is true and as "SymbolicPart & ConstPart" otherwise.
/// A value that is not an and/or with a ConstantInt operand is modelled as
/// "V | 0" (see the constructor).
class XorOpnd {
public:
  XorOpnd(Value *V);

  /// True once Invalidate() has been called.
  bool isInvalid() const { return SymbolicPart == nullptr; }
  bool isOrExpr() const { return isOr; }
  /// The value this operand was built from.
  Value *getValue() const { return OrigVal; }
  Value *getSymbolicPart() const { return SymbolicPart; }
  unsigned getSymbolicRank() const { return SymbolicRank; }
  const APInt &getConstPart() const { return ConstPart; }

  /// Mark this operand as removed from the expression.
  void Invalidate() { SymbolicPart = OrigVal = nullptr; }
  void setSymbolicRank(unsigned R) { SymbolicRank = R; }

  /// Orders XorOpnd pointers by increasing symbolic rank.
  struct PtrSortFunctor {
    bool operator()(XorOpnd * const &LHS, XorOpnd * const &RHS) const {
      return LHS->getSymbolicRank() < RHS->getSymbolicRank();
    }
  };

private:
  Value *OrigVal;
  Value *SymbolicPart;
  APInt ConstPart;
  unsigned SymbolicRank;
  bool isOr;
};
}
namespace {
// The reassociation function pass. It ranks values (BuildRankMap/getRank)
// and rewrites trees of associative, commutative operators using the private
// helpers declared below.
class Reassociate : public FunctionPass {
// Rank of each basic block, filled in by BuildRankMap.
DenseMap<BasicBlock*, unsigned> RankMap;
// Memoized rank of arguments and instructions (see getRank).
DenseMap<AssertingVH<Value>, unsigned> ValueRankMap;
// Instructions that the helpers below queue to be revisited.
SetVector<AssertingVH<Instruction> > RedoInsts;
// Set when the pass modifies the IR.
bool MadeChange;
public:
// Pass identifier, handed to the FunctionPass base-class constructor.
static char ID; Reassociate() : FunctionPass(ID) {
initializeReassociatePass(*PassRegistry::getPassRegistry());
}
bool runOnFunction(Function &F) override;
// Only the CFG is declared as preserved.
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.setPreservesCFG();
}
private:
// Ranking.
void BuildRankMap(Function &F);
unsigned getRank(Value *V);
// Expression rewriting.
void ReassociateExpression(BinaryOperator *I);
void RewriteExprTree(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops);
// Operand-list simplification; each returns a replacement value or null.
Value *OptimizeExpression(BinaryOperator *I,
SmallVectorImpl<ValueEntry> &Ops);
Value *OptimizeAdd(Instruction *I, SmallVectorImpl<ValueEntry> &Ops);
Value *OptimizeXor(Instruction *I, SmallVectorImpl<ValueEntry> &Ops);
bool CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, APInt &ConstOpnd,
Value *&Res);
bool CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, XorOpnd *Opnd2,
APInt &ConstOpnd, Value *&Res);
bool collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops,
SmallVectorImpl<Factor> &Factors);
Value *buildMinimalMultiplyDAG(IRBuilder<> &Builder,
SmallVectorImpl<Factor> &Factors);
Value *OptimizeMul(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops);
Value *RemoveFactorFromExpression(Value *V, Value *Factor);
void EraseInst(Instruction *I);
void OptimizeInst(Instruction *I);
};
}
/// Decompose V into a symbolic part and a constant part.
/// "X | C" and "X & C" (with C a ConstantInt, on either side) become
/// (X, C, isOr); any other non-constant value V becomes (V, 0, isOr = true).
/// The symbolic rank starts at 0 and is set later by the caller.
XorOpnd::XorOpnd(Value *V) {
  assert(!isa<ConstantInt>(V) && "No ConstantInt");
  OrigVal = V;
  SymbolicRank = 0;

  if (Instruction *Inst = dyn_cast<Instruction>(V)) {
    unsigned Opc = Inst->getOpcode();
    if (Opc == Instruction::Or || Opc == Instruction::And) {
      Value *Sym = Inst->getOperand(0);
      Value *Cst = Inst->getOperand(1);
      // Canonicalize so that a constant operand, if any, is in Cst.
      if (isa<ConstantInt>(Sym))
        std::swap(Sym, Cst);
      if (ConstantInt *C = dyn_cast<ConstantInt>(Cst)) {
        ConstPart = C->getValue();
        SymbolicPart = Sym;
        isOr = (Opc == Instruction::Or);
        return;
      }
    }
  }

  // Not an and/or with a constant: treat V as "V | 0".
  SymbolicPart = V;
  ConstPart = APInt::getNullValue(V->getType()->getIntegerBitWidth());
  isOr = true;
}
// Storage for the pass identifier declared in the class.
char Reassociate::ID = 0;
// Register the pass under the name "reassociate".
INITIALIZE_PASS(Reassociate, "reassociate",
"Reassociate expressions", false, false)
// Factory: returns a newly allocated Reassociate pass.
FunctionPass *llvm::createReassociatePass() { return new Reassociate(); }
/// Return V as a BinaryOperator if it is an instruction with opcode Opcode
/// that has exactly one use and, when it is floating-point math, allows
/// unsafe algebra. Otherwise return null.
static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode) {
  if (!V->hasOneUse())
    return nullptr;
  Instruction *Inst = dyn_cast<Instruction>(V);
  if (!Inst || Inst->getOpcode() != Opcode)
    return nullptr;
  // FP operations can only be reassociated under unsafe-algebra flags.
  if (isa<FPMathOperator>(V) && !Inst->hasUnsafeAlgebra())
    return nullptr;
  return cast<BinaryOperator>(V);
}
/// Like the single-opcode overload, but accepts either Opcode1 or Opcode2.
/// Returns V as a BinaryOperator when it is a single-use instruction with one
/// of those opcodes (and unsafe algebra allowed for FP math); else null.
static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode1,
                                        unsigned Opcode2) {
  if (!V->hasOneUse())
    return nullptr;
  Instruction *Inst = dyn_cast<Instruction>(V);
  if (!Inst)
    return nullptr;
  unsigned Opc = Inst->getOpcode();
  if (Opc != Opcode1 && Opc != Opcode2)
    return nullptr;
  if (isa<FPMathOperator>(V) && !Inst->hasUnsafeAlgebra())
    return nullptr;
  return cast<BinaryOperator>(V);
}
/// Return true for instructions that BuildRankMap gives a fixed, distinct rank
/// within their block: PHIs, landing pads, allocas, loads, invokes,
/// divisions/remainders, and calls other than debug-info intrinsics.
static bool isUnmovableInstruction(Instruction *I) {
  const unsigned Opc = I->getOpcode();

  // Calls are pinned unless they are debug-info intrinsics.
  if (Opc == Instruction::Call)
    return !isa<DbgInfoIntrinsic>(I);

  switch (Opc) {
  case Instruction::PHI:
  case Instruction::LandingPad:
  case Instruction::Alloca:
  case Instruction::Load:
  case Instruction::Invoke:
  // Integer and floating-point division and remainder.
  case Instruction::UDiv:
  case Instruction::SDiv:
  case Instruction::FDiv:
  case Instruction::URem:
  case Instruction::SRem:
  case Instruction::FRem:
    return true;
  default:
    return false;
  }
}
/// Seed the rank tables for F. Each argument gets its own rank, starting at 3.
/// Each basic block, visited in reverse post-order, gets the next counter
/// value shifted left by 16. Instructions for which isUnmovableInstruction
/// holds get consecutive ranks just above their block's rank.
void Reassociate::BuildRankMap(Function &F) {
  unsigned Rank = 2;

  for (Function::arg_iterator AI = F.arg_begin(), AE = F.arg_end(); AI != AE;
       ++AI) {
    ++Rank;
    ValueRankMap[&*AI] = Rank;
  }

  ReversePostOrderTraversal<Function*> RPOT(&F);
  typedef ReversePostOrderTraversal<Function*>::rpo_iterator RPOIter;
  for (RPOIter BI = RPOT.begin(), BE = RPOT.end(); BI != BE; ++BI) {
    BasicBlock *BB = *BI;
    ++Rank;
    unsigned BBRank = Rank << 16;
    RankMap[BB] = BBRank;

    // Give every unmovable instruction its own rank within this block.
    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE; ++II)
      if (isUnmovableInstruction(II))
        ValueRankMap[&*II] = ++BBRank;
  }
}
/// Return the rank of V, computing and memoizing it on first request.
/// Non-instructions: arguments use their precomputed rank, anything else
/// (globals, constants) has rank 0. Instructions: the maximum rank of their
/// operands (capped at their block's rank), plus one unless the instruction is
/// an integer/FP "not", "neg" or "fneg". That exception gives X and ~X / -X
/// the same rank.
unsigned Reassociate::getRank(Value *V) {
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return isa<Argument>(V) ? ValueRankMap[V] : 0;

  if (unsigned Known = ValueRankMap[I])
    return Known;

  unsigned Rank = 0;
  const unsigned MaxRank = RankMap[I->getParent()];
  // Stop early once the block's rank is reached; it cannot grow further.
  for (unsigned OpNo = 0, NumOps = I->getNumOperands();
       OpNo != NumOps && Rank != MaxRank; ++OpNo)
    Rank = std::max(Rank, getRank(I->getOperand(OpNo)));

  Type *Ty = V->getType();
  const bool IsFreeUnaryOp =
      (Ty->isIntegerTy() || Ty->isFloatingPointTy()) &&
      (BinaryOperator::isNot(I) || BinaryOperator::isNeg(I) ||
       BinaryOperator::isFNeg(I));
  if (!IsFreeUnaryOp)
    ++Rank;

  return ValueRankMap[I] = Rank;
}
/// Create "S1 + S2" named Name before InsertBefore. For non-integer operands
/// an fadd is created and given FlagsOp's fast-math flags (FlagsOp must then
/// be an FPMathOperator).
static BinaryOperator *CreateAdd(Value *S1, Value *S2, const Twine &Name,
                                 Instruction *InsertBefore, Value *FlagsOp) {
  if (!S1->getType()->isIntegerTy()) {
    BinaryOperator *FAdd =
        BinaryOperator::CreateFAdd(S1, S2, Name, InsertBefore);
    FAdd->setFastMathFlags(cast<FPMathOperator>(FlagsOp)->getFastMathFlags());
    return FAdd;
  }
  return BinaryOperator::CreateAdd(S1, S2, Name, InsertBefore);
}
/// Create "S1 * S2" named Name before InsertBefore. For non-integer operands
/// an fmul is created and given FlagsOp's fast-math flags (FlagsOp must then
/// be an FPMathOperator).
static BinaryOperator *CreateMul(Value *S1, Value *S2, const Twine &Name,
                                 Instruction *InsertBefore, Value *FlagsOp) {
  if (!S1->getType()->isIntegerTy()) {
    BinaryOperator *FMul =
        BinaryOperator::CreateFMul(S1, S2, Name, InsertBefore);
    FMul->setFastMathFlags(cast<FPMathOperator>(FlagsOp)->getFastMathFlags());
    return FMul;
  }
  return BinaryOperator::CreateMul(S1, S2, Name, InsertBefore);
}
/// Create the negation of S1 named Name before InsertBefore. For non-integer
/// operands an fneg is created and given FlagsOp's fast-math flags (FlagsOp
/// must then be an FPMathOperator).
static BinaryOperator *CreateNeg(Value *S1, const Twine &Name,
                                 Instruction *InsertBefore, Value *FlagsOp) {
  if (!S1->getType()->isIntegerTy()) {
    BinaryOperator *FNeg = BinaryOperator::CreateFNeg(S1, Name, InsertBefore);
    FNeg->setFastMathFlags(cast<FPMathOperator>(FlagsOp)->getFastMathFlags());
    return FNeg;
  }
  return BinaryOperator::CreateNeg(S1, Name, InsertBefore);
}
/// Replace the negation Neg with "operand(1) * -1", inserted before Neg.
/// The multiply takes over Neg's name, uses and debug location, and is
/// returned. Callers pass a value matched by BinaryOperator::isNeg/isFNeg,
/// whose operand 1 is the negated value.
static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) {
  Type *Ty = Neg->getType();
  Constant *MinusOne;
  if (Ty->isIntegerTy())
    MinusOne = ConstantInt::getAllOnesValue(Ty);
  else
    MinusOne = ConstantFP::get(Ty, -1.0);

  BinaryOperator *Mul = CreateMul(Neg->getOperand(1), MinusOne, "", Neg, Neg);
  // Drop Neg's reference to the negated value.
  Neg->setOperand(1, Constant::getNullValue(Ty));
  Mul->takeName(Neg);
  Neg->replaceAllUsesWith(Mul);
  Mul->setDebugLoc(Neg->getDebugLoc());
  return Mul;
}
/// Return log2 of Carmichael's lambda for 2^Bitwidth, i.e. the exponent k such
/// that x^(2^k) == 1 (mod 2^Bitwidth) for every odd x: Bitwidth - 1 for
/// widths 1 and 2, and Bitwidth - 2 from width 3 upwards.
static unsigned CarmichaelShift(unsigned Bitwidth) {
  return Bitwidth < 3 ? Bitwidth - 1 : Bitwidth - 2;
}
/// Merge the weight RHS into the weight LHS of the same leaf, for an
/// expression tree whose operator is Opcode. The result stays reduced:
/// - a zero weight on either side contributes nothing (the other is kept);
/// - idempotent opcodes (x op x == x) keep weight 1;
/// - nilpotent opcodes (x op x == identity) drop to weight 0;
/// - add/fadd sum the weights;
/// - mul/fmul sum the exponents and then reduce the sum below CM + Bitwidth
///   by subtracting CM = 2^CarmichaelShift(Bitwidth).
static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) {
if (RHS.isMinValue())
return;
if (LHS.isMinValue()) {
LHS = RHS;
return;
}
if (Instruction::isIdempotent(Opcode)) {
assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
return; }
if (Instruction::isNilpotent(Opcode)) {
assert(LHS == 1 && RHS == 1 && "Weights not reduced!");
LHS = 0; return;
}
if (Opcode == Instruction::Add || Opcode == Instruction::FAdd) {
LHS += RHS;
return;
}
assert((Opcode == Instruction::Mul || Opcode == Instruction::FMul) &&
"Unknown associative operation!");
unsigned Bitwidth = LHS.getBitWidth();
// For W >= CM + Bitwidth, x^W == x^(W - CM) for every Bitwidth-bit x: odd x
// satisfy x^CM == 1, and for even x both powers are 0 because the exponents
// are at least Bitwidth. So exponents can always be kept in
// [0, CM + Bitwidth).
if (Bitwidth > 3) {
APInt CM = APInt::getOneBitSet(Bitwidth, CarmichaelShift(Bitwidth));
APInt Threshold = CM + Bitwidth;
assert(LHS.ult(Threshold) && RHS.ult(Threshold) && "Weights not reduced!");
// Both inputs are below Threshold, so for Bitwidth >= 4 this sum still
// fits in Bitwidth bits.
LHS += RHS;
while (LHS.uge(Threshold))
LHS -= CM;
} else {
// For widths <= 3 the APInt sum could overflow, so do the same reduction
// in unsigned arithmetic.
unsigned CM = 1U << CarmichaelShift(Bitwidth);
unsigned Threshold = CM + Bitwidth;
assert(LHS.getZExtValue() < Threshold && RHS.getZExtValue() < Threshold &&
"Weights not reduced!");
unsigned Total = LHS.getZExtValue() + RHS.getZExtValue();
while (Total >= Threshold)
Total -= CM;
LHS = Total;
}
}
/// A leaf of a linearized expression together with its weight (the number of
/// times it occurs in the flattened expression).
using RepeatedValue = std::pair<Value*, APInt>;
/// Linearize the expression tree rooted at I into (leaf, weight) pairs that
/// are appended to Ops.
///
/// Every operand that is a single-use operator of I's (reassociable) opcode is
/// expanded; anything else is a leaf. A leaf's weight is the number of paths
/// from the root that reach it, so for I = X + A with X = A + B, A has weight 2
/// and B has weight 1. Repeated weights are merged with IncorporateWeight.
/// In a multiply tree, single-use negations are rewritten into multiplies by
/// -1 so they can join the tree.
///
/// Leaves are emitted once each, in first-seen order, and zero-weight leaves
/// are skipped. If no leaf survives, the opcode's identity constant is
/// emitted with weight 1.
///
/// \returns true if the IR was modified (a negation was lowered).
static bool LinearizeExprTree(BinaryOperator *I,
                              SmallVectorImpl<RepeatedValue> &Ops) {
  DEBUG(dbgs() << "LINEARIZE: " << *I << '\n');
  unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits();
  unsigned Opcode = I->getOpcode();
  assert(I->isAssociative() && I->isCommutative() &&
         "Expected an associative and commutative operation!");

  // Non-leaf nodes still to expand, each paired with the number of paths from
  // the root that reach it.
  SmallVector<std::pair<BinaryOperator*, APInt>, 8> Worklist; // (Op, Weight)
  Worklist.push_back(std::make_pair(I, APInt(Bitwidth, 1)));
  bool MadeChange = false;

  // Leaves are values that are not operators of the right kind, or that are
  // but have uses outside the expression. Their weights accumulate here.
  typedef DenseMap<Value*, APInt> LeafMap;
  LeafMap Leaves;                   // Leaf -> total weight so far.
  SmallVector<Value*, 8> LeafOrder; // Keeps the output order deterministic.

  // A preprocessor directive must start its own line. The `#endif` below used
  // to follow the declaration on the same line, which left `#ifndef NDEBUG`
  // unterminated.
#ifndef NDEBUG
  // Used only by assertions, to check that each node is expanded once.
  SmallPtrSet<Value*, 8> Visited;
#endif

  while (!Worklist.empty()) {
    std::pair<BinaryOperator*, APInt> P = Worklist.pop_back_val();
    I = P.first; // Examine the operands of this operator.

    for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx) {
      Value *Op = I->getOperand(OpIdx);
      APInt Weight = P.second; // Number of paths to this operand.
      DEBUG(dbgs() << "OPERAND: " << *Op << " (" << Weight << ")\n");
      assert(!Op->use_empty() && "No uses, so how did we get to it?!");

      // A single-use operator of the right kind: expand its operands too.
      if (BinaryOperator *BO = isReassociableOp(Op, Opcode)) {
        assert(Visited.insert(Op) && "Not first visit!");
        DEBUG(dbgs() << "DIRECT ADD: " << *Op << " (" << Weight << ")\n");
        Worklist.push_back(std::make_pair(BO, Weight));
        continue;
      }

      // Looks like a leaf. Has it been seen before?
      LeafMap::iterator It = Leaves.find(Op);
      if (It == Leaves.end()) {
        // First time this operand is seen.
        assert(Visited.insert(Op) && "Not first visit!");
        if (!Op->hasOneUse()) {
          // Used outside the expression, so it must not be modified.
          DEBUG(dbgs() << "ADD USES LEAF: " << *Op << " (" << Weight << ")\n");
          LeafOrder.push_back(Op);
          Leaves[Op] = Weight;
          continue;
        }
        // Only used here: fall through and try morphing it.
      } else {
        // Already a leaf: fold the new paths into its weight.
        assert(Visited.count(Op) && "In leaf map but not visited!");
        IncorporateWeight(It->second, Weight, Opcode);
#if 0 // TODO: Re-enable once PR13021 is fixed.
        assert(!Op->hasOneUse() && "Only one use, but we got here twice!");
        I->setOperand(OpIdx, UndefValue::get(I->getType()));
        MadeChange = true;
        if (BinaryOperator *BO = isReassociableOp(Op, Opcode)) {
          DEBUG(dbgs() << "UNLEAF: " << *Op << " (" << It->second << ")\n");
          Worklist.push_back(std::make_pair(BO, It->second));
          Leaves.erase(It);
          continue;
        }
#endif
        // Still has uses outside the expression: leave it alone.
        if (!Op->hasOneUse())
          continue;

        // All uses are now inside the expression; try morphing it with its
        // accumulated weight.
        Weight = It->second;
        Leaves.erase(It); // The value may be morphed below.
      }

      // Op is not an operator of the right kind and is only used inside the
      // expression, so it may safely be modified.
      assert((!isa<Instruction>(Op) ||
              cast<Instruction>(Op)->getOpcode() != Opcode
              || (isa<FPMathOperator>(Op) &&
                  !cast<Instruction>(Op)->hasUnsafeAlgebra())) &&
             "Should have been handled above!");
      assert(Op->hasOneUse() && "Has uses outside the expression tree!");

      // In a multiply tree, turn a negation into a multiply by -1 so that it
      // can be reassociated along with the rest of the tree.
      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op))
        if ((Opcode == Instruction::Mul && BinaryOperator::isNeg(BO)) ||
            (Opcode == Instruction::FMul && BinaryOperator::isFNeg(BO))) {
          DEBUG(dbgs() << "MORPH LEAF: " << *Op << " (" << Weight << ") TO ");
          BO = LowerNegateToMultiply(BO);
          DEBUG(dbgs() << *BO << '\n');
          Worklist.push_back(std::make_pair(BO, Weight));
          MadeChange = true;
          continue;
        }

      // Could not morph it, so this really is a leaf.
      DEBUG(dbgs() << "ADD LEAF: " << *Op << " (" << Weight << ")\n");
      assert(!isReassociableOp(Op, Opcode) && "Value was morphed?");
      LeafOrder.push_back(Op);
      Leaves[Op] = Weight;
    }
  }

  // Emit the leaves, repeated according to their weights, in first-seen order.
  for (unsigned i = 0, e = LeafOrder.size(); i != e; ++i) {
    Value *V = LeafOrder[i];
    LeafMap::iterator It = Leaves.find(V);
    if (It == Leaves.end())
      continue; // Initially looked like a leaf but was expanded later.
    assert(!isReassociableOp(V, Opcode) && "Shouldn't be a leaf!");
    APInt Weight = It->second;
    if (Weight.isMinValue())
      continue; // Already emitted, or its weight reduced to zero.
    It->second = 0; // Make sure the leaf is emitted only once.
    Ops.push_back(std::make_pair(V, Weight));
  }

  // Every weight may have reduced to 0 (e.g. "X xor X"): the result is then
  // the operation's identity value.
  if (Ops.empty()) {
    Constant *Identity = ConstantExpr::getBinOpIdentity(Opcode, I->getType());
    assert(Identity && "Associative operation without identity!");
    Ops.push_back(std::make_pair(Identity, APInt(Bitwidth, 1)));
  }

  return MadeChange;
}
/// Rewrite the tree of Opcode operators rooted at I so that it combines the
/// operands in Ops (at least two) as a left-leaning chain. I's RHS becomes
/// Ops[0], the RHS of I's LHS becomes Ops[1], and so on. The deepest node has
/// LHS Ops[n-2] and RHS Ops[n-1]. Single-use operators of the old tree are
/// reused as inner nodes where possible. New nodes (created with undef
/// operands that the loop then fills in) are made only when none are left.
void Reassociate::RewriteExprTree(BinaryOperator *I,
SmallVectorImpl<ValueEntry> &Ops) {
assert(Ops.size() > 1 && "Single values should be used directly!");
// Old inner nodes detached from the tree, available for reuse.
SmallVector<BinaryOperator*, 8> NodesToRewrite;
unsigned Opcode = I->getOpcode();
BinaryOperator *Op = I;
// The new operands must never be reused as inner nodes.
SmallPtrSet<Value*, 8> NotRewritable;
for (unsigned i = 0, e = Ops.size(); i != e; ++i)
NotRewritable.insert(Ops[i].Op);
// Deepest node whose operands changed. After the loop, flags and positions
// are fixed up from here up to I.
BinaryOperator *ExpressionChanged = nullptr;
for (unsigned i = 0; ; ++i) {
// Bottom of the chain: this node takes the last two operands.
if (i+2 == Ops.size()) {
Value *NewLHS = Ops[i].Op;
Value *NewRHS = Ops[i+1].Op;
Value *OldLHS = Op->getOperand(0);
Value *OldRHS = Op->getOperand(1);
// Already in the required form.
if (NewLHS == OldLHS && NewRHS == OldRHS)
break;
// Right operands, wrong order: swapping suffices.
if (NewLHS == OldRHS && NewRHS == OldLHS) {
DEBUG(dbgs() << "RA: " << *Op << '\n');
Op->swapOperands();
DEBUG(dbgs() << "TO: " << *Op << '\n');
MadeChange = true;
++NumChanged;
break;
}
DEBUG(dbgs() << "RA: " << *Op << '\n');
// Replace whichever operands differ, keeping displaced inner nodes of
// the old tree for reuse.
if (NewLHS != OldLHS) {
BinaryOperator *BO = isReassociableOp(OldLHS, Opcode);
if (BO && !NotRewritable.count(BO))
NodesToRewrite.push_back(BO);
Op->setOperand(0, NewLHS);
}
if (NewRHS != OldRHS) {
BinaryOperator *BO = isReassociableOp(OldRHS, Opcode);
if (BO && !NotRewritable.count(BO))
NodesToRewrite.push_back(BO);
Op->setOperand(1, NewRHS);
}
DEBUG(dbgs() << "TO: " << *Op << '\n');
ExpressionChanged = Op;
MadeChange = true;
++NumChanged;
break;
}
// Not the bottom yet: the RHS of this node must be Ops[i].
Value *NewRHS = Ops[i].Op;
if (NewRHS != Op->getOperand(1)) {
DEBUG(dbgs() << "RA: " << *Op << '\n');
if (NewRHS == Op->getOperand(0)) {
// The wanted operand is on the left; swapping brings it over.
Op->swapOperands();
} else {
BinaryOperator *BO = isReassociableOp(Op->getOperand(1), Opcode);
if (BO && !NotRewritable.count(BO))
NodesToRewrite.push_back(BO);
Op->setOperand(1, NewRHS);
ExpressionChanged = Op;
}
DEBUG(dbgs() << "TO: " << *Op << '\n');
MadeChange = true;
++NumChanged;
}
// The LHS must be an inner node. Keep the existing one if it is a usable
// operator of the tree.
BinaryOperator *BO = isReassociableOp(Op->getOperand(0), Opcode);
if (BO && !NotRewritable.count(BO)) {
Op = BO;
continue;
}
// Otherwise reuse a detached node, or create a fresh one with undef operands
// (filled in on later iterations).
BinaryOperator *NewOp;
if (NodesToRewrite.empty()) {
Constant *Undef = UndefValue::get(I->getType());
NewOp = BinaryOperator::Create(Instruction::BinaryOps(Opcode),
Undef, Undef, "", I);
if (NewOp->getType()->isFloatingPointTy())
NewOp->setFastMathFlags(I->getFastMathFlags());
} else {
NewOp = NodesToRewrite.pop_back_val();
}
DEBUG(dbgs() << "RA: " << *Op << '\n');
Op->setOperand(0, NewOp);
DEBUG(dbgs() << "TO: " << *Op << '\n');
ExpressionChanged = Op;
MadeChange = true;
++NumChanged;
Op = NewOp;
}
// Walk from the deepest changed node up to I (through each node's first
// user). For each node, clear its optional flags (restoring I's fast-math
// flags for FP math) and move it before I. I itself is not moved.
if (ExpressionChanged)
do {
if (isa<FPMathOperator>(I)) {
FastMathFlags Flags = I->getFastMathFlags();
ExpressionChanged->clearSubclassOptionalData();
ExpressionChanged->setFastMathFlags(Flags);
} else
ExpressionChanged->clearSubclassOptionalData();
if (ExpressionChanged == I)
break;
ExpressionChanged->moveBefore(I);
ExpressionChanged = cast<BinaryOperator>(*ExpressionChanged->user_begin());
} while (1);
// Old inner nodes that were not reused are queued to be revisited.
for (unsigned i = 0, e = NodesToRewrite.size(); i != e; ++i)
RedoInsts.insert(NodesToRewrite[i]);
}
/// Return a value equal to the negation of V. Cases, in order:
/// - constants are folded with ConstantExpr::getFNeg / getNeg;
/// - a single-use add/fadd is negated by recursively negating both of its
///   operands in place, then moving it before BI and appending ".neg" to its
///   name;
/// - an existing neg/fneg of V in the same function is reused after being
///   moved just after V's definition (past any PHIs; for an invoke, to the
///   start of its normal destination; for a non-instruction, to the start of
///   the entry block);
/// - otherwise a new negation named "<V>.neg" is inserted before BI.
static Value *NegateValue(Value *V, Instruction *BI) {
if (ConstantFP *C = dyn_cast<ConstantFP>(V))
return ConstantExpr::getFNeg(C);
if (Constant *C = dyn_cast<Constant>(V))
return ConstantExpr::getNeg(C);
// Push the negation through the add: -(A+B) == (-A) + (-B).
if (BinaryOperator *I =
isReassociableOp(V, Instruction::Add, Instruction::FAdd)) {
I->setOperand(0, NegateValue(I->getOperand(0), BI));
I->setOperand(1, NegateValue(I->getOperand(1), BI));
// New negations may have been inserted before BI, so the add is moved to
// BI as well to stay after its operands.
I->moveBefore(BI);
I->setName(I->getName()+".neg");
return I;
}
// Look for an existing negation of V to reuse.
for (User *U : V->users()) {
if (!BinaryOperator::isNeg(U) && !BinaryOperator::isFNeg(U))
continue;
BinaryOperator *TheNeg = cast<BinaryOperator>(U);
// V may be a constant expression used in other functions; skip those uses.
if (TheNeg->getParent()->getParent() != BI->getParent()->getParent())
continue;
BasicBlock::iterator InsertPt;
if (Instruction *InstInput = dyn_cast<Instruction>(V)) {
if (InvokeInst *II = dyn_cast<InvokeInst>(InstInput)) {
InsertPt = II->getNormalDest()->begin();
} else {
InsertPt = InstInput;
++InsertPt;
}
while (isa<PHINode>(InsertPt)) ++InsertPt;
} else {
InsertPt = TheNeg->getParent()->getParent()->getEntryBlock().begin();
}
TheNeg->moveBefore(InsertPt);
return TheNeg;
}
// No reusable negation: create one.
return CreateNeg(V, V->getName() + ".neg", BI, BI);
}
/// Decide whether the subtract Sub should be rewritten as an add of a negation
/// (see BreakUpSubtract). That is worthwhile when Sub is not itself a
/// negation, its RHS is not undef, and it is connected to a single-use
/// add/sub (integer or FP): either as one of its operands, or as its only
/// user.
static bool ShouldBreakUpSubtract(Instruction *Sub) {
  // A negation (0 - X / fneg X) has nothing to break up.
  if (BinaryOperator::isNeg(Sub) || BinaryOperator::isFNeg(Sub))
    return false;

  // Don't break up X - undef.
  if (isa<UndefValue>(Sub->getOperand(1)))
    return false;

  Value *V0 = Sub->getOperand(0);
  if (isReassociableOp(V0, Instruction::Add, Instruction::FAdd) ||
      isReassociableOp(V0, Instruction::Sub, Instruction::FSub))
    return true;
  Value *V1 = Sub->getOperand(1);
  if (isReassociableOp(V1, Instruction::Add, Instruction::FAdd) ||
      isReassociableOp(V1, Instruction::Sub, Instruction::FSub))
    return true;

  // Query the user only once we know there is exactly one. user_back() on a
  // value without users would dereference an empty use list.
  if (!Sub->hasOneUse())
    return false;
  Value *VB = Sub->user_back();
  return isReassociableOp(VB, Instruction::Add, Instruction::FAdd) ||
         isReassociableOp(VB, Instruction::Sub, Instruction::FSub);
}
/// Rewrite "A - B" as "A + (-B)" so the subtraction can be reassociated with
/// other adds. The new add is inserted before Sub and takes over Sub's name,
/// uses and debug location. Sub's operands are replaced by zero.
static BinaryOperator *BreakUpSubtract(Instruction *Sub) {
  Value *NegRHS = NegateValue(Sub->getOperand(1), Sub);
  Value *LHS = Sub->getOperand(0);
  BinaryOperator *Add = CreateAdd(LHS, NegRHS, "", Sub, Sub);

  // Drop Sub's uses of its operands.
  Constant *Zero = Constant::getNullValue(Sub->getType());
  Sub->setOperand(0, Zero);
  Sub->setOperand(1, Zero);
  Add->takeName(Sub);

  Sub->replaceAllUsesWith(Add);
  Add->setDebugLoc(Sub->getDebugLoc());

  DEBUG(dbgs() << "Negated: " << *Add << '\n');
  return Add;
}
/// Rewrite "X << C" (C a constant) as "X * (1 << C)", inserted before Shl.
/// The multiply takes over Shl's name, uses and debug location. Shl's first
/// operand is replaced by undef.
static BinaryOperator *ConvertShiftToMul(Instruction *Shl) {
  Type *Ty = Shl->getType();
  Constant *ShAmt = cast<Constant>(Shl->getOperand(1));
  Constant *Scale = ConstantExpr::getShl(ConstantInt::get(Ty, 1), ShAmt);

  BinaryOperator *Mul =
      BinaryOperator::CreateMul(Shl->getOperand(0), Scale, "", Shl);
  Shl->setOperand(0, UndefValue::get(Ty)); // Drop Shl's use of X.
  Mul->takeName(Shl);
  Shl->replaceAllUsesWith(Mul);
  Mul->setDebugLoc(Shl->getDebugLoc());
  return Mul;
}
/// Look for X among the entries adjacent to Ops[i] that have the same rank as
/// Ops[i]. The search goes forward first, then backward, and stops at the
/// first entry of a different rank. Returns X's index, or i if X is not found.
static unsigned FindInOperandList(SmallVectorImpl<ValueEntry> &Ops, unsigned i,
                                  Value *X) {
  const unsigned TargetRank = Ops[i].Rank;
  const unsigned Size = Ops.size();

  for (unsigned Fwd = i + 1; Fwd < Size; ++Fwd) {
    if (Ops[Fwd].Rank != TargetRank)
      break;
    if (Ops[Fwd].Op == X)
      return Fwd;
  }

  for (unsigned Back = i; Back > 0;) {
    --Back;
    if (Ops[Back].Rank != TargetRank)
      break;
    if (Ops[Back].Op == X)
      return Back;
  }

  return i;
}
/// Sum the values in Ops (which must not be empty) with a chain of adds named
/// "tmp", inserted before I and associated to the left in Ops order:
/// ((Ops[0] + Ops[1]) + Ops[2]) + ... A single value is returned unchanged.
/// On return, Ops holds only its first element.
static Value *EmitAddTreeOfValues(Instruction *I,
                                  SmallVectorImpl<WeakVH> &Ops){
  Value *Sum = Ops[0];
  for (unsigned Idx = 1, End = Ops.size(); Idx != End; ++Idx)
    Sum = CreateAdd(Sum, Ops[Idx], "tmp", I, I);

  // Leave Ops holding only its first element.
  while (Ops.size() > 1)
    Ops.pop_back();
  return Sum;
}
/// If V is a single-use mul/fmul tree, remove one occurrence of Factor from it
/// and return what remains. For constant factors, an occurrence of the negated
/// constant also counts; in that case the result is negated (a "neg" inserted
/// after V). If only one operand remains it is returned directly and V is
/// queued for revisiting; otherwise V is rewritten and returned.
/// Returns null if V is not such a tree or Factor does not occur in it. In
/// the not-found case the linearized tree is written back into V.
Value *Reassociate::RemoveFactorFromExpression(Value *V, Value *Factor) {
BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul);
if (!BO)
return nullptr;
// Flatten the tree, expanding each leaf into as many entries as its weight.
SmallVector<RepeatedValue, 8> Tree;
MadeChange |= LinearizeExprTree(BO, Tree);
SmallVector<ValueEntry, 8> Factors;
Factors.reserve(Tree.size());
for (unsigned i = 0, e = Tree.size(); i != e; ++i) {
RepeatedValue E = Tree[i];
Factors.append(E.second.getZExtValue(),
ValueEntry(getRank(E.first), E.first));
}
// Remove the first occurrence of Factor (or of its negation, for constants).
bool FoundFactor = false;
bool NeedsNegate = false;
for (unsigned i = 0, e = Factors.size(); i != e; ++i) {
if (Factors[i].Op == Factor) {
FoundFactor = true;
Factors.erase(Factors.begin()+i);
break;
}
if (ConstantInt *FC1 = dyn_cast<ConstantInt>(Factor)) {
if (ConstantInt *FC2 = dyn_cast<ConstantInt>(Factors[i].Op))
if (FC1->getValue() == -FC2->getValue()) {
FoundFactor = NeedsNegate = true;
Factors.erase(Factors.begin()+i);
break;
}
} else if (ConstantFP *FC1 = dyn_cast<ConstantFP>(Factor)) {
if (ConstantFP *FC2 = dyn_cast<ConstantFP>(Factors[i].Op)) {
APFloat F1(FC1->getValueAPF());
APFloat F2(FC2->getValueAPF());
F2.changeSign();
if (F1.compare(F2) == APFloat::cmpEqual) {
FoundFactor = NeedsNegate = true;
Factors.erase(Factors.begin() + i);
break;
}
}
}
}
if (!FoundFactor) {
// Write the linearized operands back so the tree is well formed again.
RewriteExprTree(BO, Factors);
return nullptr;
}
// Insertion point for a possible negation: just after BO.
BasicBlock::iterator InsertPt = BO; ++InsertPt;
if (Factors.size() == 1) {
RedoInsts.insert(BO);
V = Factors[0].Op;
} else {
RewriteExprTree(BO, Factors);
V = BO;
}
if (NeedsNegate)
V = CreateNeg(V, "neg", InsertPt, BO);
return V;
}
/// Append to Factors the leaves of the single-use mul/fmul tree rooted at V.
/// At each multiply the RHS subtree is visited before the LHS. A V that is
/// not such a multiply is itself the only factor. Ops is not used.
static void FindSingleUseMultiplyFactors(Value *V,
                                         SmallVectorImpl<Value*> &Factors,
                                         const SmallVectorImpl<ValueEntry> &Ops) {
  BinaryOperator *Mul =
      isReassociableOp(V, Instruction::Mul, Instruction::FMul);
  if (Mul == nullptr) {
    Factors.push_back(V);
    return;
  }
  for (int OpNo = 1; OpNo >= 0; --OpNo)
    FindSingleUseMultiplyFactors(Mul->getOperand(OpNo), Factors, Ops);
}
/// Simplify the operand list of an and/or/xor expression in place.
/// - And/Or containing both X and ~X (of the same rank) fold to 0 / all-ones
///   and that constant is returned. (X ^ ~X is not handled here.)
/// - Adjacent duplicate operands: for And/Or one copy is dropped; for Xor the
///   pair cancels, and if they were the only two operands 0 is returned.
/// Duplicates are only detected when adjacent, so the caller presumably
/// passes Ops sorted. Returns null when the expression was not folded to a
/// single constant (Ops may still have been modified).
static Value *OptimizeAndOrXor(unsigned Opcode,
SmallVectorImpl<ValueEntry> &Ops) {
for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
assert(i < Ops.size());
if (BinaryOperator::isNot(Ops[i].Op)) { Value *X = BinaryOperator::getNotArgument(Ops[i].Op);
unsigned FoundX = FindInOperandList(Ops, i, X);
if (FoundX != i) {
if (Opcode == Instruction::And) return Constant::getNullValue(X->getType());
if (Opcode == Instruction::Or) return Constant::getAllOnesValue(X->getType());
}
}
assert(i < Ops.size());
if (i+1 != Ops.size() && Ops[i+1].Op == Ops[i].Op) {
// X & X == X and X | X == X: drop one copy.  "--i" (which may wrap when
// i == 0) plus the loop's ++i re-examines the same index.
if (Opcode == Instruction::And || Opcode == Instruction::Or) {
Ops.erase(Ops.begin()+i);
--i; --e;
++NumAnnihil;
continue;
}
// X ^ X == 0: drop the pair.
assert(Opcode == Instruction::Xor);
if (e == 2)
return Constant::getNullValue(Ops[0].Op->getType());
Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
i -= 1; e -= 2;
++NumAnnihil;
}
}
return nullptr;
}
/// Materialize "Opnd & ConstOpnd".
/// Returns null when ConstOpnd is 0 (the result would be zero), Opnd itself
/// when ConstOpnd is all ones, and otherwise a new "and.ra" instruction
/// inserted before InsertBefore that carries InsertBefore's debug location.
static Value *createAndInstr(Instruction *InsertBefore, Value *Opnd,
                             const APInt &ConstOpnd) {
  if (ConstOpnd == 0)
    return nullptr;
  if (ConstOpnd.isAllOnesValue())
    return Opnd;

  LLVMContext &Ctx = Opnd->getType()->getContext();
  Instruction *AndInst =
      BinaryOperator::CreateAnd(Opnd, ConstantInt::get(Ctx, ConstOpnd),
                                "and.ra", InsertBefore);
  AndInst->setDebugLoc(InsertBefore->getDebugLoc());
  return AndInst;
}
/// Xor-Rule 1: (x | c1) ^ c2 == (x & ~c1) ^ (c1 ^ c2), applied only when
/// c1 == c2 (so the constant disappears) and "x | c1" has a single use.
/// On success, Res receives "x & ~c1" (null if that is zero), ConstOpnd is
/// updated to c1 ^ c2, Opnd1's original value is queued for revisiting, and
/// true is returned.
bool Reassociate::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1,
                                 APInt &ConstOpnd, Value *&Res) {
  if (!Opnd1->isOrExpr() || Opnd1->getConstPart() == 0)
    return false;
  if (!Opnd1->getValue()->hasOneUse())
    return false;

  const APInt &C1 = Opnd1->getConstPart();
  if (C1 != ConstOpnd)
    return false;

  Res = createAndInstr(I, Opnd1->getSymbolicPart(), ~C1);
  ConstOpnd ^= C1;

  if (Instruction *T = dyn_cast<Instruction>(Opnd1->getValue()))
    RedoInsts.insert(T);
  return true;
}
/// Try to combine two xor operands that share the same symbolic part x:
/// - Rule 2: (x | c1) ^ (x & c2) == (x & c3) ^ c1,  with c3 = ~c1 ^ c2
/// - Rule 3: (x | c1) ^ (x | c2) == (x & c3) ^ c3,  with c3 = c1 ^ c2
/// - Rule 4: (x & c1) ^ (x & c2) == x & (c1 ^ c2)
/// On success, Res receives the "x & c3" part (null if it is zero), the
/// constant part is xor-ed into ConstOpnd, both original operand values are
/// queued for revisiting, and true is returned. Rules 2 and 3 are skipped
/// when they would not pay off in instruction count.
bool Reassociate::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, XorOpnd *Opnd2,
APInt &ConstOpnd, Value *&Res) {
Value *X = Opnd1->getSymbolicPart();
if (X != Opnd2->getSymbolicPart())
return false;
// Instructions expected to become dead: the xor joining the two operands,
// plus each operand whose only use is this expression.
int DeadInstNum = 1;
if (Opnd1->getValue()->hasOneUse())
DeadInstNum++;
if (Opnd2->getValue()->hasOneUse())
DeadInstNum++;
// Rule 2 (one "or", one "and").
if (Opnd1->isOrExpr() != Opnd2->isOrExpr()) {
// Make Opnd1 the "or" operand.
if (Opnd2->isOrExpr())
std::swap(Opnd1, Opnd2);
const APInt &C1 = Opnd1->getConstPart();
const APInt &C2 = Opnd2->getConstPart();
APInt C3((~C1) ^ C2);
// An actual "and" is needed only when c3 is neither 0 nor all ones.
// Estimated new instructions: the "and", plus an xor with a constant when
// there was no constant operand yet. Do not grow the code.
if (C3 != 0 && !C3.isAllOnesValue()) {
int NewInstNum = ConstOpnd != 0 ? 1 : 2;
if (NewInstNum > DeadInstNum)
return false;
}
Res = createAndInstr(I, X, C3);
ConstOpnd ^= C1;
} else if (Opnd1->isOrExpr()) {
// Rule 3 (two "or"s), with the same code-size check.
const APInt &C1 = Opnd1->getConstPart();
const APInt &C2 = Opnd2->getConstPart();
APInt C3 = C1 ^ C2;
if (C3 != 0 && !C3.isAllOnesValue()) {
int NewInstNum = ConstOpnd != 0 ? 1 : 2;
if (NewInstNum > DeadInstNum)
return false;
}
Res = createAndInstr(I, X, C3);
ConstOpnd ^= C3;
} else {
// Rule 4 (two "and"s): never adds instructions, so always applied.
const APInt &C1 = Opnd1->getConstPart();
const APInt &C2 = Opnd2->getConstPart();
APInt C3 = C1 ^ C2;
Res = createAndInstr(I, X, C3);
}
if (Instruction *T = dyn_cast<Instruction>(Opnd1->getValue()))
RedoInsts.insert(T);
if (Instruction *T = dyn_cast<Instruction>(Opnd2->getValue()))
RedoInsts.insert(T);
return true;
}
/// Simplify the operand list Ops of the xor I.
/// After the generic pass (OptimizeAndOrXor), non-constant operands are
/// decomposed into XorOpnds and all ConstantInt operands are folded into
/// ConstOpnd. Operands are then combined, with the constant and with each
/// other, using the Xor-Rules in CombineXorOpnd. If anything changed, Ops is
/// rebuilt from the surviving operands plus the folded constant.
/// Returns the whole result when it reduces to one operand or to the constant
/// 0, otherwise null.
Value *Reassociate::OptimizeXor(Instruction *I,
SmallVectorImpl<ValueEntry> &Ops) {
if (Value *V = OptimizeAndOrXor(Instruction::Xor, Ops))
return V;
if (Ops.size() == 1)
return nullptr;
SmallVector<XorOpnd, 8> Opnds;
SmallVector<XorOpnd*, 8> OpndPtrs;
Type *Ty = Ops[0].Op->getType();
APInt ConstOpnd(Ty->getIntegerBitWidth(), 0);
// Step 1: constants are xor-ed into ConstOpnd; every other operand becomes
// an XorOpnd ranked by its symbolic part.
for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
Value *V = Ops[i].Op;
if (!isa<ConstantInt>(V)) {
XorOpnd O(V);
O.setSymbolicRank(getRank(O.getSymbolicPart()));
Opnds.push_back(O);
} else
ConstOpnd ^= cast<ConstantInt>(V)->getValue();
}
// OpndPtrs points into Opnds, so from here on Opnds must not change size;
// growing it could reallocate the storage and leave these pointers dangling.
for (unsigned i = 0, e = Opnds.size(); i != e; ++i)
OpndPtrs.push_back(&Opnds[i]);
// Step 2: stable-sort by symbolic rank. Operands with the same symbolic part
// have the same rank and therefore end up in the same rank group.
std::stable_sort(OpndPtrs.begin(), OpndPtrs.end(), XorOpnd::PtrSortFunctor());
// Step 3: walk the sorted operands, combining as we go.
XorOpnd *PrevOpnd = nullptr;
bool Changed = false;
for (unsigned i = 0, e = Opnds.size(); i < e; i++) {
XorOpnd *CurrOpnd = OpndPtrs[i];
Value *CV;
// Step 3.1: try "CurrOpnd ^ ConstOpnd" (Xor-Rule 1). A null CV means the
// operand vanished.
if (ConstOpnd != 0 && CombineXorOpnd(I, CurrOpnd, ConstOpnd, CV)) {
Changed = true;
if (CV)
*CurrOpnd = XorOpnd(CV);
else {
CurrOpnd->Invalidate();
continue;
}
}
if (!PrevOpnd || CurrOpnd->getSymbolicPart() != PrevOpnd->getSymbolicPart()) {
PrevOpnd = CurrOpnd;
continue;
}
// Step 3.2: the previous operand has the same symbolic part; try
// "PrevOpnd ^ CurrOpnd ^ ConstOpnd" (Xor-Rules 2-4).
if (CombineXorOpnd(I, CurrOpnd, PrevOpnd, ConstOpnd, CV)) {
PrevOpnd->Invalidate();
if (CV) {
*CurrOpnd = XorOpnd(CV);
PrevOpnd = CurrOpnd;
} else {
CurrOpnd->Invalidate();
PrevOpnd = nullptr;
}
Changed = true;
}
}
// Step 4: rebuild Ops from the surviving operands and the folded constant.
if (Changed) {
Ops.clear();
for (unsigned int i = 0, e = Opnds.size(); i < e; i++) {
XorOpnd &O = Opnds[i];
if (O.isInvalid())
continue;
ValueEntry VE(getRank(O.getValue()), O.getValue());
Ops.push_back(VE);
}
if (ConstOpnd != 0) {
Value *C = ConstantInt::get(Ty->getContext(), ConstOpnd);
ValueEntry VE(getRank(C), C);
Ops.push_back(VE);
}
int Sz = Ops.size();
if (Sz == 1)
return Ops.back().Op;
else if (Sz == 0) {
// Everything cancelled (and ConstOpnd is 0): the xor is 0.
assert(ConstOpnd == 0);
return ConstantInt::get(Ty->getContext(), ConstOpnd);
}
}
return nullptr;
}
/// Simplify the operand list Ops of the add/fadd I.
/// Pass 1: a run of n identical adjacent operands X becomes a new multiply
/// "X * n". X together with -X cancels (to 0), and X together with ~X becomes
/// -1; if they were the only two operands, that constant is returned.
/// Pass 2: a factor that occurs in at least two single-use multiply operands
/// is pulled out, e.g. A*B + A*C + D -> A*(B+C) + D.
/// Returns the whole result when one is found, otherwise null (Ops may still
/// have been updated in place).
Value *Reassociate::OptimizeAdd(Instruction *I,
SmallVectorImpl<ValueEntry> &Ops) {
// Pass 1: duplicates, and X / -X and X / ~X pairs.
for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
Value *TheOp = Ops[i].Op;
// Equal operands are expected to be adjacent; replace the whole run with a
// multiply by its length.
if (i+1 != Ops.size() && Ops[i+1].Op == TheOp) {
unsigned NumFound = 0;
do {
Ops.erase(Ops.begin()+i);
++NumFound;
} while (i != Ops.size() && Ops[i].Op == TheOp);
DEBUG(errs() << "\nFACTORING [" << NumFound << "]: " << *TheOp << '\n');
++NumFactor;
Type *Ty = TheOp->getType();
Constant *C = Ty->isIntegerTy() ? ConstantInt::get(Ty, NumFound)
: ConstantFP::get(Ty, NumFound);
Instruction *Mul = CreateMul(TheOp, C, "factor", I, I);
// Queue the new multiply so that it is optimized in turn.
RedoInsts.insert(Mul);
// Every operand was a copy of TheOp: the multiply is the whole sum.
if (Ops.empty())
return Mul;
Ops.insert(Ops.begin(), ValueEntry(getRank(Mul), Mul));
// "--i" (wrapping when i == 0) plus the loop's ++i re-examines index i.
--i;
e = Ops.size();
continue;
}
// Otherwise only negations (-X) and complements (~X) are of interest.
if (!BinaryOperator::isNeg(TheOp) && !BinaryOperator::isFNeg(TheOp) &&
!BinaryOperator::isNot(TheOp))
continue;
Value *X = nullptr;
if (BinaryOperator::isNeg(TheOp) || BinaryOperator::isFNeg(TheOp))
X = BinaryOperator::getNegArgument(TheOp);
else if (BinaryOperator::isNot(TheOp))
X = BinaryOperator::getNotArgument(TheOp);
// Look for X among the neighbouring operands of the same rank.
unsigned FoundX = FindInOperandList(Ops, i, X);
if (FoundX == i)
continue;
// X + -X == 0.
if (Ops.size() == 2 &&
(BinaryOperator::isNeg(TheOp) || BinaryOperator::isFNeg(TheOp)))
return Constant::getNullValue(X->getType());
// X + ~X == -1.
if (Ops.size() == 2 && BinaryOperator::isNot(TheOp))
return Constant::getAllOnesValue(X->getType());
// Remove TheOp and X.  Note: the "else" governs only "--i"; the second
// erase always runs.
Ops.erase(Ops.begin()+i);
if (i < FoundX)
--FoundX;
else
--i; Ops.erase(Ops.begin()+FoundX);
++NumAnnihil;
--i; e -= 2;
// X + ~X contributes -1 to the sum.
if (BinaryOperator::isNot(TheOp)) {
Value *V = Constant::getAllOnesValue(X->getType());
Ops.insert(Ops.end(), ValueEntry(getRank(V), V));
e += 1;
}
}
// Pass 2: count, for each factor, how many single-use multiply operands
// contain it, and remember the most frequent one.
DenseMap<Value*, unsigned> FactorOccurrences;
unsigned MaxOcc = 0;
Value *MaxOccVal = nullptr;
for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
BinaryOperator *BOp =
isReassociableOp(Ops[i].Op, Instruction::Mul, Instruction::FMul);
if (!BOp)
continue;
SmallVector<Value*, 8> Factors;
FindSingleUseMultiplyFactors(BOp, Factors, Ops);
assert(Factors.size() > 1 && "Bad linearize!");
// Count each distinct factor at most once per multiply.
SmallPtrSet<Value*, 8> Duplicates;
for (unsigned i = 0, e = Factors.size(); i != e; ++i) {
Value *Factor = Factors[i];
if (!Duplicates.insert(Factor))
continue;
unsigned Occ = ++FactorOccurrences[Factor];
if (Occ > MaxOcc) {
MaxOcc = Occ;
MaxOccVal = Factor;
}
// A negative constant -C also counts toward C, since the negation can be
// pulled out. The minimum signed integer has no positive counterpart.
if (ConstantInt *CI = dyn_cast<ConstantInt>(Factor)) {
if (CI->isNegative() && !CI->isMinValue(true)) {
Factor = ConstantInt::get(CI->getContext(), -CI->getValue());
assert(!Duplicates.count(Factor) &&
"Shouldn't have two constant factors, missed a canonicalize");
unsigned Occ = ++FactorOccurrences[Factor];
if (Occ > MaxOcc) {
MaxOcc = Occ;
MaxOccVal = Factor;
}
}
} else if (ConstantFP *CF = dyn_cast<ConstantFP>(Factor)) {
if (CF->isNegative()) {
APFloat F(CF->getValueAPF());
F.changeSign();
Factor = ConstantFP::get(CF->getContext(), F);
assert(!Duplicates.count(Factor) &&
"Shouldn't have two constant factors, missed a canonicalize");
unsigned Occ = ++FactorOccurrences[Factor];
if (Occ > MaxOcc) {
MaxOcc = Occ;
MaxOccVal = Factor;
}
}
}
}
}
// A factor shared by at least two multiplies is pulled out of the sum.
if (MaxOcc > 1) {
DEBUG(errs() << "\nFACTORING [" << MaxOcc << "]: " << *MaxOccVal << '\n');
++NumFactor;
// Detached instruction (never inserted, deleted below) that holds two extra
// uses of MaxOccVal while factors are removed. Presumably this keeps
// MaxOccVal's use count, which isReassociableOp checks, from changing
// between the RemoveFactorFromExpression calls.
Instruction *DummyInst =
I->getType()->isIntegerTy()
? BinaryOperator::CreateAdd(MaxOccVal, MaxOccVal)
: BinaryOperator::CreateFAdd(MaxOccVal, MaxOccVal);
SmallVector<WeakVH, 4> NewMulOps;
for (unsigned i = 0; i != Ops.size(); ++i) {
BinaryOperator *BOp =
isReassociableOp(Ops[i].Op, Instruction::Mul, Instruction::FMul);
if (!BOp)
continue;
if (Value *V = RemoveFactorFromExpression(Ops[i].Op, MaxOccVal)) {
// The same multiply may occur several times in Ops; replace every
// occurrence (entry i included) with the reduced product.
for (unsigned j = Ops.size(); j != i;) {
--j;
if (Ops[j].Op == Ops[i].Op) {
NewMulOps.push_back(V);
Ops.erase(Ops.begin()+j);
}
}
// Entry i was erased; revisit this index (wraps when i == 0).
--i;
}
}
delete DummyInst;
// Sum the reduced products, then multiply the sum by the common factor.
unsigned NumAddedValues = NewMulOps.size();
Value *V = EmitAddTreeOfValues(I, NewMulOps);
assert(NumAddedValues > 1 && "Each occurrence should contribute a value");
(void)NumAddedValues;
if (Instruction *VI = dyn_cast<Instruction>(V))
RedoInsts.insert(VI);
Instruction *V2 = CreateMul(V, MaxOccVal, "tmp", I, I);
RedoInsts.insert(V2);
// Every operand contained the factor: the multiply is the whole sum.
if (Ops.empty())
return V2;
Ops.insert(Ops.begin(), ValueEntry(getRank(V2), V2));
}
return nullptr;
}
/// Move repeated operands of a multiply's operand list Ops into Factors.
/// Ops is expected to have equal operands adjacent. Nothing is done (false is
/// returned and Ops is unchanged) unless the operands that occur at least
/// twice account for a total count of 4 or more. Otherwise, for each such
/// operand an even number of its occurrences (all of them, or all but one) is
/// moved out of Ops into a Factor(Op, Count). Factors is then stable-sorted
/// by descending power and true is returned.
bool Reassociate::collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops,
SmallVectorImpl<Factor> &Factors) {
// First pass: total count of operands occurring at least twice.
unsigned FactorPowerSum = 0;
for (unsigned Idx = 1, Size = Ops.size(); Idx < Size; ++Idx) {
Value *Op = Ops[Idx-1].Op;
unsigned Count = 1;
for (; Idx < Size && Ops[Idx].Op == Op; ++Idx)
++Count;
if (Count > 1)
FactorPowerSum += Count;
}
if (FactorPowerSum < 4)
return false;
// Second pass: move an even number of occurrences of each repeated operand
// into Factors.
FactorPowerSum = 0;
for (unsigned Idx = 1; Idx < Ops.size(); ++Idx) {
Value *Op = Ops[Idx-1].Op;
unsigned Count = 1;
for (; Idx < Ops.size() && Ops[Idx].Op == Op; ++Idx)
++Count;
if (Count == 1)
continue;
// Round down to even; an odd leftover stays in Ops. Idx is moved back to
// the start of the erased range so the outer loop resumes at the next
// distinct operand.
Count &= ~1U;
Idx -= Count;
FactorPowerSum += Count;
Factors.push_back(Factor(Op, Count));
Ops.erase(Ops.begin()+Idx, Ops.begin()+Idx+Count);
}
assert(FactorPowerSum >= 4);
std::stable_sort(Factors.begin(), Factors.end(), Factor::PowerDescendingSorter());
return true;
}
/// Multiply together every value in \p Ops and return the product.
///
/// Values are consumed from the back of Ops, so the emitted chain is
/// ((Ops[N-1] * Ops[N-2]) * ...) * Ops[0]. Integer-typed operands use `mul`;
/// any other type uses `fmul`. All operands are expected to share one type,
/// as LLVM multiplies require.
///
/// Precondition: Ops is non-empty.
/// Side effect on Ops: a single-element Ops is returned as-is and left in
/// place; otherwise Ops is fully drained.
static Value *buildMultiplyTree(IRBuilder<> &Builder,
                                SmallVectorImpl<Value*> &Ops) {
  assert(!Ops.empty() && "Cannot build a multiply tree with no operands");
  if (Ops.size() == 1)
    return Ops.back();

  Value *LHS = Ops.pop_back_val();
  // A mul/fmul result has the same type as its operands, so the integer-vs-FP
  // decision made on the first operand holds for the whole chain. Decide once
  // instead of re-querying the type on every iteration.
  const bool IsIntegerProduct = LHS->getType()->isIntegerTy();
  do {
    Value *RHS = Ops.pop_back_val();
    LHS = IsIntegerProduct ? Builder.CreateMul(LHS, RHS)
                           : Builder.CreateFMul(LHS, RHS);
  } while (!Ops.empty());

  return LHS;
}
/// Emit IR computing the product of Base^Power over all entries of
/// \p Factors, using repeated squaring so that each power is built from
/// shared sub-products.
///
/// Factors is expected to be ordered by non-increasing Power, as produced by
/// collectMultiplyFactors, with Factors[0].Power != 0. Factors is used as
/// scratch space: bases are overwritten, entries are erased, and powers are
/// halved, including by the recursive call. Newly created multiply
/// instructions from the equal-power merge step are queued on RedoInsts.
Value *Reassociate::buildMinimalMultiplyDAG(IRBuilder<> &Builder,
                                            SmallVectorImpl<Factor> &Factors) {
  assert(Factors[0].Power);

  // Step 1: within the positive-power prefix, collapse each run of factors
  // sharing a power into one base, since b1^p * b2^p == (b1*b2)^p. The merged
  // product replaces the run's first base; the other run members are removed
  // by the unique() below.
  for (unsigned RunStart = 0, Size = Factors.size();
       RunStart < Size && Factors[RunStart].Power > 0;) {
    const unsigned RunPower = Factors[RunStart].Power;
    unsigned RunEnd = RunStart + 1;
    while (RunEnd < Size && Factors[RunEnd].Power == RunPower)
      ++RunEnd;

    if (RunEnd - RunStart > 1) {
      SmallVector<Value *, 4> InnerProduct;
      for (unsigned K = RunStart; K != RunEnd; ++K)
        InnerProduct.push_back(Factors[K].Base);
      Value *Merged = buildMultiplyTree(Builder, InnerProduct);
      Factors[RunStart].Base = Merged;
      if (Instruction *MergedInst = dyn_cast<Instruction>(Merged))
        RedoInsts.insert(MergedInst);
    }
    RunStart = RunEnd;
  }

  // Equal powers are adjacent (ordered input), so this keeps one entry per
  // power: the one that now holds the merged base.
  Factors.erase(std::unique(Factors.begin(), Factors.end(),
                            Factor::PowerEqual()),
                Factors.end());

  // Step 2: prod(b^p) == prod_{p odd}(b) * (prod b^(p/2))^2.
  // Odd-power bases go straight into the outer product, and every power is
  // halved for the recursive "square root" computation.
  SmallVector<Value *, 4> OuterProduct;
  for (Factor &Fac : Factors) {
    if (Fac.Power & 1)
      OuterProduct.push_back(Fac.Base);
    Fac.Power >>= 1;
  }

  if (Factors[0].Power) {
    Value *SquareRoot = buildMinimalMultiplyDAG(Builder, Factors);
    OuterProduct.push_back(SquareRoot);
    OuterProduct.push_back(SquareRoot);
  }

  // buildMultiplyTree hands back a lone operand without emitting anything.
  return buildMultiplyTree(Builder, OuterProduct);
}
/// Try to rebuild a multiply chain with repeated operands as a minimal
/// multiply DAG.
///
/// Requires at least four operands and enough repetition for
/// collectMultiplyFactors to succeed; otherwise returns nullptr and leaves
/// Ops unchanged.
///
/// If every operand was absorbed into the DAG, its root is returned.
/// Otherwise the root is inserted into Ops at the position given by
/// std::lower_bound under ValueEntry's ordering (Ops is assumed to already be
/// in that order), and nullptr is returned.
Value *Reassociate::OptimizeMul(BinaryOperator *I,
                                SmallVectorImpl<ValueEntry> &Ops) {
  SmallVector<Factor, 4> Factors;
  if (Ops.size() < 4 || !collectMultiplyFactors(Ops, Factors))
    return nullptr;

  IRBuilder<> Builder(I);
  Value *Product = buildMinimalMultiplyDAG(Builder, Factors);
  if (!Ops.empty()) {
    ValueEntry NewEntry(getRank(Product), Product);
    Ops.insert(std::lower_bound(Ops.begin(), Ops.end(), NewEntry), NewEntry);
    return nullptr;
  }
  return Product;
}
/// Simplify the linearized operand list \p Ops of the expression rooted at
/// \p I.
///
/// Steps:
///  1. Fold the trailing constants of Ops into a single constant.
///  2. Drop that constant if it is the opcode's identity. Return it outright
///     if it is the absorbing element.
///  3. Run the opcode-specific simplifier.
///
/// Returns:
///  - the replacement value when the whole expression collapses to one value;
///  - nullptr when Ops (possibly modified in place) should be re-emitted.
///
/// If the opcode-specific step changed the operand count, the routine is
/// re-run on the new list.
Value *Reassociate::OptimizeExpression(BinaryOperator *I,
                                       SmallVectorImpl<ValueEntry> &Ops) {
  const unsigned Opcode = I->getOpcode();

  // Constants are expected to have sorted to the back of Ops; fold every
  // trailing constant into one.
  Constant *Folded = nullptr;
  while (!Ops.empty() && isa<Constant>(Ops.back().Op)) {
    Constant *C = cast<Constant>(Ops.pop_back_val().Op);
    Folded = Folded ? ConstantExpr::get(Opcode, C, Folded) : C;
  }

  // Nothing but constants: the folded constant is the whole expression.
  if (Ops.empty())
    return Folded;

  // Re-append the folded constant unless it is the identity, which is simply
  // dropped. The absorbing element swallows the whole expression.
  if (Folded &&
      Folded != ConstantExpr::getBinOpIdentity(Opcode, I->getType())) {
    if (Folded == ConstantExpr::getBinOpAbsorber(Opcode, I->getType()))
      return Folded;
    Ops.push_back(ValueEntry(0, Folded));
  }

  if (Ops.size() == 1)
    return Ops[0].Op;

  const unsigned NumOpsBefore = Ops.size();
  Value *Result = nullptr;
  switch (Opcode) {
  case Instruction::And:
  case Instruction::Or:
    Result = OptimizeAndOrXor(Opcode, Ops);
    break;
  case Instruction::Xor:
    Result = OptimizeXor(I, Ops);
    break;
  case Instruction::Add:
  case Instruction::FAdd:
    Result = OptimizeAdd(I, Ops);
    break;
  case Instruction::Mul:
  case Instruction::FMul:
    Result = OptimizeMul(I, Ops);
    break;
  default:
    break;
  }
  if (Result)
    return Result;

  // A changed operand count may expose further folding, so try again.
  return Ops.size() == NumOpsBefore ? nullptr : OptimizeExpression(I, Ops);
}
/// Erase the trivially dead instruction \p I and queue its former operands
/// for re-optimization.
///
/// For each operand that is an instruction, the code walks up through
/// single-use users with the same opcode. It queues the topmost such node on
/// RedoInsts rather than the operand itself. Visited ensures no node is
/// climbed through twice, so a self-referential chain cannot loop forever.
void Reassociate::EraseInst(Instruction *I) {
  assert(isInstructionTriviallyDead(I) && "Trivially dead instructions only!");

  // Snapshot the operands now; they cannot be reached through I once it has
  // been erased.
  SmallVector<Value*, 8> Ops(I->op_begin(), I->op_end());
  ValueRankMap.erase(I);
  RedoInsts.remove(I);
  I->eraseFromParent();

  SmallPtrSet<Instruction *, 8> Visited;
  for (Value *Operand : Ops) {
    Instruction *Op = dyn_cast<Instruction>(Operand);
    if (!Op)
      continue;
    const unsigned Opcode = Op->getOpcode();
    while (Op->hasOneUse() && Op->user_back()->getOpcode() == Opcode &&
           Visited.insert(Op))
      Op = Op->user_back();
    RedoInsts.insert(Op);
  }
}
/// Visit one instruction.
///
/// Steps:
///  1. Canonicalise a few operator forms: shl-by-constant becomes a multiply,
///     FAdd/FMul operands are put in rank order, and subtract/negate forms are
///     rewritten.
///  2. If the resulting instruction is the root of an associative expression,
///     reassociate it.
///
/// Replaced instructions are queued on RedoInsts; the driver erases them
/// there once they are trivially dead.
void Reassociate::OptimizeInst(Instruction *I) {
  // This pass only handles binary operators.
  if (!isa<BinaryOperator>(I))
    return;

  // Retire I in favour of NewI: queue the old instruction for revisiting,
  // record that the function changed, and continue the rest of this routine
  // with NewI.
  auto Supersede = [&](Instruction *NewI) {
    RedoInsts.insert(I);
    MadeChange = true;
    I = NewI;
  };

  // Rewrite "X << C" as a multiply when X is a reassociable multiply, or when
  // the shift's single user is a reassociable multiply or add.
  if (I->getOpcode() == Instruction::Shl &&
      isa<ConstantInt>(I->getOperand(1)) &&
      (isReassociableOp(I->getOperand(0), Instruction::Mul) ||
       (I->hasOneUse() &&
        (isReassociableOp(I->user_back(), Instruction::Mul) ||
         isReassociableOp(I->user_back(), Instruction::Add)))))
    Supersede(ConvertShiftToMul(I));

  // Floating-point and vector types: put the operands of FAdd/FMul in
  // non-decreasing rank order. Then stop here unless this is a scalar
  // operation that permits unsafe algebra.
  if (I->getType()->isFloatingPointTy() || I->getType()->isVectorTy()) {
    const unsigned FPOpcode = I->getOpcode();
    if (FPOpcode == Instruction::FAdd || FPOpcode == Instruction::FMul) {
      Value *Op0 = I->getOperand(0);
      Value *Op1 = I->getOperand(1);
      const unsigned Rank0 = getRank(Op0);
      const unsigned Rank1 = getRank(Op1);
      if (Rank1 < Rank0) {
        I->setOperand(0, Op1);
        I->setOperand(1, Op0);
      }
    }
    if (I->getType()->isVectorTy() || !I->hasUnsafeAlgebra())
      return;
  }

  // i1 expressions are deliberately left untouched.
  if (I->getType()->isIntegerTy(1))
    return;

  // Integer and FP subtraction share one shape; only the negation test and
  // the matching multiply opcode differ.
  const unsigned SubOpcode = I->getOpcode();
  if (SubOpcode == Instruction::Sub || SubOpcode == Instruction::FSub) {
    const bool IsIntSub = SubOpcode == Instruction::Sub;
    const unsigned MulOpcode = IsIntSub ? Instruction::Mul : Instruction::FMul;
    if (ShouldBreakUpSubtract(I)) {
      Supersede(BreakUpSubtract(I));
    } else if (IsIntSub ? BinaryOperator::isNeg(I)
                        : BinaryOperator::isFNeg(I)) {
      // A negation of a reassociable multiply is lowered to a multiply,
      // unless its only user is itself a reassociable multiply.
      if (isReassociableOp(I->getOperand(1), MulOpcode) &&
          (!I->hasOneUse() || !isReassociableOp(I->user_back(), MulOpcode)))
        Supersede(LowerNegateToMultiply(I));
    }
  }

  // Only associative operators are reassociated.
  if (!I->isAssociative())
    return;
  BinaryOperator *BO = cast<BinaryOperator>(I);

  // Skip interior nodes. These are single-use operators whose user has the
  // same opcode, or an Add/FAdd whose user is a Sub/FSub. Presumably the
  // enclosing root reassociates them when it is visited.
  if (BO->hasOneUse()) {
    const unsigned Opcode = BO->getOpcode();
    const unsigned UserOpcode = BO->user_back()->getOpcode();
    if (UserOpcode == Opcode ||
        (Opcode == Instruction::Add && UserOpcode == Instruction::Sub) ||
        (Opcode == Instruction::FAdd && UserOpcode == Instruction::FSub))
      return;
  }

  ReassociateExpression(BO);
}
/// Reassociate the scalar expression tree rooted at \p I.
///
/// Steps:
///  1. Flatten the tree into ranked operands and stable-sort them by
///     ValueEntry's ordering.
///  2. Simplify the operand list.
///  3. Either replace I outright (when the expression collapses to a single
///     value) or rewrite the tree from the resulting operand list.
void Reassociate::ReassociateExpression(BinaryOperator *I) {
  assert(!I->getType()->isVectorTy() &&
         "Reassociation of vector instructions is not supported.");

  // Flatten into (value, multiplicity) pairs, then expand each pair into
  // that many ranked entries.
  SmallVector<RepeatedValue, 8> Tree;
  MadeChange |= LinearizeExprTree(I, Tree);
  SmallVector<ValueEntry, 8> Ops;
  Ops.reserve(Tree.size());
  for (const RepeatedValue &Entry : Tree)
    Ops.append(Entry.second.getZExtValue(),
               ValueEntry(getRank(Entry.first), Entry.first));

  DEBUG(dbgs() << "RAIn:\t"; PrintOps(I, Ops); dbgs() << '\n');

  // stable_sort keeps equal-rank operands in their original relative order,
  // which keeps the output deterministic.
  std::stable_sort(Ops.begin(), Ops.end());

  if (Value *V = OptimizeExpression(I, Ops)) {
    if (V == I)
      return; // Self-referential expression; nothing to do.
    DEBUG(dbgs() << "Reassoc to scalar: " << *V << '\n');
    I->replaceAllUsesWith(V);
    if (Instruction *VI = dyn_cast<Instruction>(V))
      VI->setDebugLoc(I->getDebugLoc());
    RedoInsts.insert(I);
    ++NumAnnihil;
    return;
  }

  // When a single-use Mul (FMul) feeds an Add (FAdd) and the last operand is
  // the constant -1 (-1.0), move that constant to the front of Ops instead of
  // leaving it at the back.
  if (I->hasOneUse()) {
    const unsigned Opcode = I->getOpcode();
    const unsigned UserOpcode = I->user_back()->getOpcode();
    bool MoveNegOneToFront = false;
    if (Opcode == Instruction::Mul && UserOpcode == Instruction::Add) {
      ConstantInt *CI = dyn_cast<ConstantInt>(Ops.back().Op);
      MoveNegOneToFront = CI && CI->isAllOnesValue();
    } else if (Opcode == Instruction::FMul &&
               UserOpcode == Instruction::FAdd) {
      ConstantFP *CF = dyn_cast<ConstantFP>(Ops.back().Op);
      MoveNegOneToFront = CF && CF->isExactlyValue(-1.0);
    }
    if (MoveNegOneToFront) {
      ValueEntry Tmp = Ops.pop_back_val();
      Ops.insert(Ops.begin(), Tmp);
    }
  }

  DEBUG(dbgs() << "RAOut:\t"; PrintOps(I, Ops); dbgs() << '\n');

  if (Ops.size() != 1) {
    RewriteExprTree(I, Ops);
    return;
  }

  // The expression reduced to a single operand: forward I's uses to it.
  Value *Sole = Ops[0].Op;
  if (Sole == I)
    return; // Self-referential expression; nothing to do.
  I->replaceAllUsesWith(Sole);
  if (Instruction *SoleInst = dyn_cast<Instruction>(Sole))
    SoleInst->setDebugLoc(I->getDebugLoc());
  RedoInsts.insert(I);
}
/// Pass entry point.
///
/// Builds the rank map, then processes each basic block in two phases:
///  1. Sweep the block's instructions in order, erasing trivially dead ones
///     and optimizing the rest.
///  2. Drain the RedoInsts worklist those steps filled.
///
/// The rank maps are cleared before returning. Returns whether anything
/// changed.
bool Reassociate::runOnFunction(Function &F) {
  if (skipOptnoneFunction(F))
    return false;

  BuildRankMap(F);
  MadeChange = false;

  for (Function::iterator BB = F.begin(), BBEnd = F.end(); BB != BBEnd; ++BB) {
    // Sweep the block. The iterator is advanced before erasing, and only
    // after optimizing, so OptimizeInst sees the current instruction in place.
    BasicBlock::iterator It = BB->begin(), End = BB->end();
    while (It != End) {
      if (!isInstructionTriviallyDead(It)) {
        OptimizeInst(It);
        assert(It->getParent() == BB && "Moved to a different block!");
        ++It;
      } else {
        EraseInst(It++);
      }
    }

    // Process everything queued while handling this block.
    while (!RedoInsts.empty()) {
      Instruction *Pending = RedoInsts.pop_back_val();
      if (!isInstructionTriviallyDead(Pending))
        OptimizeInst(Pending);
      else
        EraseInst(Pending);
    }
  }

  // The rank information is only valid for this run.
  RankMap.clear();
  ValueRankMap.clear();
  return MadeChange;
}