InstCombineAddSub.cpp [plain text]
#include "InstCombine.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Support/GetElementPtrTypeIterator.h"
#include "llvm/Support/PatternMatch.h"
using namespace llvm;
using namespace PatternMatch;
/// AddOne - Return the constant expression "C + 1", where the 1 is created in
/// C's own type.
static Constant *AddOne(Constant *C) {
  Constant *One = ConstantInt::get(C->getType(), 1);
  return ConstantExpr::getAdd(C, One);
}
/// SubOne - Return a new ConstantInt in C's context whose value is C's APInt
/// value minus one.
static Constant *SubOne(ConstantInt *C) {
  const APInt Decremented = C->getValue() - 1;
  return ConstantInt::get(C->getContext(), Decremented);
}
/// dyn_castFoldableMul - If V is a single-use integer instruction of the form
/// "X * C" or "X << C" with a ConstantInt right operand, return X and set CST
/// to the equivalent multiplier (C for a mul, 1 << C for a shl). Otherwise
/// return null. For a mul/shl whose right operand is not a ConstantInt, CST is
/// still overwritten (with null) before null is returned.
static inline Value *dyn_castFoldableMul(Value *V, ConstantInt *&CST) {
  // Only single-use integer values are candidates.
  if (!V->hasOneUse() || !V->getType()->isIntegerTy())
    return 0;

  Instruction *Inst = dyn_cast<Instruction>(V);
  if (!Inst)
    return 0;

  switch (Inst->getOpcode()) {
  case Instruction::Mul:
    CST = dyn_cast<ConstantInt>(Inst->getOperand(1));
    return CST ? Inst->getOperand(0) : 0;

  case Instruction::Shl: {
    CST = dyn_cast<ConstantInt>(Inst->getOperand(1));
    if (!CST)
      return 0;
    // Treat "X << C" as "X * (1 << C)"; the shift amount is clamped to the
    // bit width by getLimitedValue.
    uint32_t Width = cast<IntegerType>(V->getType())->getBitWidth();
    uint32_t ShiftAmt = CST->getLimitedValue(Width);
    APInt Multiplier = APInt(Width, 1).shl(ShiftAmt);
    CST = ConstantInt::get(V->getType()->getContext(), Multiplier);
    return Inst->getOperand(0);
  }

  default:
    return 0;
  }
}
/// WillNotOverflowSignedAdd - Return true if LHS + RHS is proven not to
/// overflow as a signed addition; false means "unknown", not "overflows".
/// The only heuristic used: if both operands have at least two sign bits,
/// each lies in [-2^(n-2), 2^(n-2)-1], so their sum fits in n signed bits.
bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS) {
  return ComputeNumSignBits(LHS) > 1 && ComputeNumSignBits(RHS) > 1;
}
/// visitAdd - Combine an integer (or integer vector) add instruction I.
/// Returns a replacement instruction to be inserted in place of I, &I if I
/// was simplified in place, or null if no fold applied. The folds are tried
/// in order and the first one that fires wins.
Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
bool Changed = SimplifyCommutative(I);
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
// Let InstructionSimplify handle folds that need no new instructions.
if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(),
I.hasNoUnsignedWrap(), TD))
return ReplaceInstUsesWith(I, V);
// Folds with a constant right-hand side.
if (Constant *RHSC = dyn_cast<Constant>(RHS)) {
if (ConstantInt *CI = dyn_cast<ConstantInt>(RHSC)) {
// X + signbit --> X ^ signbit
const APInt& Val = CI->getValue();
uint32_t BitWidth = Val.getBitWidth();
if (Val == APInt::getSignBit(BitWidth))
return BinaryOperator::CreateXor(LHS, RHS);
// Let demanded-bits analysis simplify I in place.
if (SimplifyDemandedInstructionBits(I))
return &I;
// zext(i1 b) + C --> select b, C+1, C
if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS))
if (ZI->getSrcTy() == Type::getInt1Ty(I.getContext()))
return SelectInst::Create(ZI->getOperand(0), AddOne(CI), CI);
}
// Try to push the add into the incoming values of a PHI.
if (isa<PHINode>(LHS))
if (Instruction *NV = FoldOpIntoPhi(I))
return NV;
// Sign-extension idiom: (X ^ C1) + C2 where C1 == -C2 and one of C1/C2 is
// a power of two 2^k. If the bits of X that the shift below would discard
// are known zero, rewrite as (X << Amt) ashr Amt with Amt = width - k - 1.
ConstantInt *XorRHS = 0;
Value *XorLHS = 0;
if (isa<ConstantInt>(RHSC) &&
match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) {
uint32_t TySizeBits = I.getType()->getScalarSizeInBits();
const APInt& RHSVal = cast<ConstantInt>(RHSC)->getValue();
unsigned ExtendAmt = 0;
if (XorRHS->getValue() == -RHSVal) {
if (RHSVal.isPowerOf2())
ExtendAmt = TySizeBits - RHSVal.logBase2() - 1;
else if (XorRHS->getValue().isPowerOf2())
ExtendAmt = TySizeBits - XorRHS->getValue().logBase2() - 1;
}
// Only valid if the top ExtendAmt bits of X are known to be zero.
if (ExtendAmt) {
APInt Mask = APInt::getHighBitsSet(TySizeBits, ExtendAmt);
if (!MaskedValueIsZero(XorLHS, Mask))
ExtendAmt = 0;
}
if (ExtendAmt) {
Constant *ShAmt = ConstantInt::get(I.getType(), ExtendAmt);
Value *NewShl = Builder->CreateShl(XorLHS, ShAmt, "sext");
return BinaryOperator::CreateAShr(NewShl, ShAmt);
}
}
}
// An i1 add is an xor.
if (I.getType()->isIntegerTy(1))
return BinaryOperator::CreateXor(LHS, RHS);
if (I.getType()->isIntegerTy()) {
// X + X --> X << 1
if (LHS == RHS)
return BinaryOperator::CreateShl(LHS, ConstantInt::get(I.getType(), 1));
// A + (B - A) --> B
if (Instruction *RHSI = dyn_cast<Instruction>(RHS)) {
if (RHSI->getOpcode() == Instruction::Sub)
if (LHS == RHSI->getOperand(1)) return ReplaceInstUsesWith(I, RHSI->getOperand(0));
}
// (B - A) + A --> B
if (Instruction *LHSI = dyn_cast<Instruction>(LHS)) {
if (LHSI->getOpcode() == Instruction::Sub)
if (RHS == LHSI->getOperand(1)) return ReplaceInstUsesWith(I, LHSI->getOperand(0));
}
}
// -A + -B --> -(A + B)   (integer types only)
// -A + B  --> B - A
if (Value *LHSV = dyn_castNegVal(LHS)) {
if (LHS->getType()->isIntOrIntVectorTy()) {
if (Value *RHSV = dyn_castNegVal(RHS)) {
Value *NewAdd = Builder->CreateAdd(LHSV, RHSV, "sum");
return BinaryOperator::CreateNeg(NewAdd);
}
}
return BinaryOperator::CreateSub(RHS, LHSV);
}
// A + -B --> A - B (skipped when RHS is a constant)
if (!isa<Constant>(RHS))
if (Value *V = dyn_castNegVal(RHS))
return BinaryOperator::CreateSub(LHS, V);
// X*C + X --> X * (C+1);  X*C1 + X*C2 --> X * (C1+C2).
// dyn_castFoldableMul also matches X << C as X * (1 << C).
ConstantInt *C2;
if (Value *X = dyn_castFoldableMul(LHS, C2)) {
if (X == RHS) return BinaryOperator::CreateMul(RHS, AddOne(C2));
ConstantInt *C1;
if (X == dyn_castFoldableMul(RHS, C1))
return BinaryOperator::CreateMul(X, ConstantExpr::getAdd(C1, C2));
}
// X + X*C --> X * (C+1)
if (dyn_castFoldableMul(RHS, C2) == LHS)
return BinaryOperator::CreateMul(LHS, AddOne(C2));
// X + ~X --> -1   (since ~X == -X - 1)
if (match(LHS, m_Not(m_Specific(RHS))) ||
match(RHS, m_Not(m_Specific(LHS))))
return ReplaceInstUsesWith(I, Constant::getAllOnesValue(I.getType()));
// A + B --> A | B when every bit is known zero in at least one operand
// (no bit position is set in both, so no carries are produced).
if (const IntegerType *IT = dyn_cast<IntegerType>(I.getType())) {
APInt Mask = APInt::getAllOnesValue(IT->getBitWidth());
APInt LHSKnownOne(IT->getBitWidth(), 0);
APInt LHSKnownZero(IT->getBitWidth(), 0);
ComputeMaskedBits(LHS, Mask, LHSKnownZero, LHSKnownOne);
// Skip the RHS query when nothing is known zero in LHS.
if (LHSKnownZero != 0) {
APInt RHSKnownOne(IT->getBitWidth(), 0);
APInt RHSKnownZero(IT->getBitWidth(), 0);
ComputeMaskedBits(RHS, Mask, RHSKnownZero, RHSKnownOne);
if ((LHSKnownZero|RHSKnownZero).isAllOnesValue())
return BinaryOperator::CreateOr(LHS, RHS);
}
}
// W*X + Y*Z --> W * (X + Z) when the two multiplies share a factor; the
// swaps commute the mul operands so the common factor ends up in W and Y.
if (I.getType()->isIntOrIntVectorTy()) {
Value *W, *X, *Y, *Z;
if (match(LHS, m_Mul(m_Value(W), m_Value(X))) &&
match(RHS, m_Mul(m_Value(Y), m_Value(Z)))) {
if (W != Y) {
if (W == Z) {
std::swap(Y, Z);
} else if (Y == X) {
std::swap(W, X);
} else if (X == Z) {
std::swap(Y, Z);
std::swap(W, X);
}
}
if (W == Y) {
Value *NewAdd = Builder->CreateAdd(X, Z, LHS->getName());
return BinaryOperator::CreateMul(W, NewAdd);
}
}
}
if (ConstantInt *CRHS = dyn_cast<ConstantInt>(RHS)) {
Value *X = 0;
// ~X + C --> (C-1) - X
if (match(LHS, m_Not(m_Value(X)))) return BinaryOperator::CreateSub(SubOne(CRHS), X);
// (X & Mask) + C --> (X + C) & Mask, when LHS has one use, (C & Mask) == C,
// and Mask contains every bit from C's lowest set bit upward.
if (LHS->hasOneUse() &&
match(LHS, m_And(m_Value(X), m_ConstantInt(C2)))) {
Constant *Anded = ConstantExpr::getAnd(CRHS, C2);
if (Anded == CRHS) {
const APInt &AddRHSV = CRHS->getValue();
// All bits from the lowest set bit of the addend through the top.
APInt AddRHSHighBits(~((AddRHSV & -AddRHSV)-1));
APInt AddRHSHighBitsAnd(AddRHSHighBits & C2->getValue());
if (AddRHSHighBits == AddRHSHighBitsAnd) {
Value *NewAdd = Builder->CreateAdd(X, CRHS, LHS->getName());
return BinaryOperator::CreateAnd(NewAdd, C2);
}
}
}
// Try to fold the constant add into the operands of a select.
if (SelectInst *SI = dyn_cast<SelectInst>(LHS))
if (Instruction *R = FoldOpIntoSelect(I, SI))
return R;
}
// add (select C, (sub N, A), 0), A --> select C, N, A
// add (select C, 0, (sub N, A)), A --> select C, A, N
// The select may be either operand; it must have a single use.
{
SelectInst *SI = dyn_cast<SelectInst>(LHS);
Value *A = RHS;
if (!SI) {
SI = dyn_cast<SelectInst>(RHS);
A = LHS;
}
if (SI && SI->hasOneUse()) {
Value *TV = SI->getTrueValue();
Value *FV = SI->getFalseValue();
Value *N;
if (match(FV, m_Zero()) &&
match(TV, m_Sub(m_Value(N), m_Specific(A))))
return SelectInst::Create(SI->getCondition(), N, A);
if (match(TV, m_Zero()) &&
match(FV, m_Sub(m_Value(N), m_Specific(A))))
return SelectInst::Create(SI->getCondition(), A, N);
}
}
if (SExtInst *LHSConv = dyn_cast<SExtInst>(LHS)) {
// (add (sext x), C) --> sext (add nsw x, C') where C' is C truncated to
// x's type, provided sext(C') == C, the sext has one use, and the narrow
// add cannot overflow.
if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) {
Constant *CI =
ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType());
if (LHSConv->hasOneUse() &&
ConstantExpr::getSExt(CI, I.getType()) == RHSC &&
WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI)) {
Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
CI, "addconv");
return new SExtInst(NewAdd, I.getType());
}
}
// (add (sext x), (sext y)) --> sext (add nsw x, y) when x and y have the
// same type, at least one sext has a single use (so the number of sexts
// does not grow), and the narrow add cannot overflow.
if (SExtInst *RHSConv = dyn_cast<SExtInst>(RHS)) {
if (LHSConv->getOperand(0)->getType()==RHSConv->getOperand(0)->getType()&&
(LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
WillNotOverflowSignedAdd(LHSConv->getOperand(0),
RHSConv->getOperand(0))) {
Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
RHSConv->getOperand(0), "addconv");
return new SExtInst(NewAdd, I.getType());
}
}
}
return Changed ? &I : 0;
}
/// visitFAdd - Combine a floating-point add instruction I.
/// Returns a replacement instruction to be inserted in place of I, &I if the
/// operands were reordered in place, or null if no fold applied.
Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
  bool Changed = SimplifyCommutative(I);
  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);

  if (Constant *RHSC = dyn_cast<Constant>(RHS)) {
    // X + -0.0 --> X
    if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHSC)) {
      if (CFP->isExactlyValue(ConstantFP::getNegativeZero
                              (I.getType())->getValueAPF()))
        return ReplaceInstUsesWith(I, LHS);
    }

    // Try to push the add into the incoming values of a PHI.
    if (isa<PHINode>(LHS))
      if (Instruction *NV = FoldOpIntoPhi(I))
        return NV;
  }

  // -A + B  -->  B - A
  if (Value *LHSV = dyn_castFNegVal(LHS))
    return BinaryOperator::CreateFSub(RHS, LHSV);

  // A + -B  -->  A - B (skipped when RHS is a constant)
  if (!isa<Constant>(RHS))
    if (Value *V = dyn_castFNegVal(RHS))
      return BinaryOperator::CreateFSub(LHS, V);

  // X + 0.0 --> X, only if X cannot be -0.0 (because -0.0 + 0.0 == +0.0).
  if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS))
    if (CFP->getValueAPF().isPosZero() && CannotBeNegativeZero(LHS))
      return ReplaceInstUsesWith(I, LHS);

  // Try to turn (fadd (sitofp x), y) into an integer add followed by a single
  // sitofp.
  if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) {
    Value *LHSIntVal = LHSConv->getOperand(0);

    // The rewrite turns "round(x) + round(y)" into "round(x + y)". The two
    // agree only if every value of the integer type (hence x, y and the
    // non-overflowing x + y) is exactly representable in the FP type. For
    // example, with i64 and double: (double)(2^53+1) + 1.0 == 2^53, but
    // (double)(2^53+2) == 2^53+2. getFPMantissaWidth() returns -1 when the
    // width is not meaningful, which disables the fold.
    int MantissaWidth = I.getType()->getScalarType()->getFPMantissaWidth();
    bool IntFitsInFP =
        MantissaWidth > 0 &&
        LHSIntVal->getType()->getScalarSizeInBits() <= (unsigned)MantissaWidth;

    if (IntFitsInFP) {
      // (fadd (sitofp x), fpcst) --> (sitofp (add nsw x, intcst)) when fpcst
      // round-trips exactly through the integer type.
      if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS)) {
        Constant *CI =
          ConstantExpr::getFPToSI(CFP, LHSIntVal->getType());
        if (LHSConv->hasOneUse() &&
            ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
            WillNotOverflowSignedAdd(LHSIntVal, CI)) {
          Value *NewAdd = Builder->CreateNSWAdd(LHSIntVal, CI, "addconv");
          return new SIToFPInst(NewAdd, I.getType());
        }
      }

      // (fadd (sitofp x), (sitofp y)) --> (sitofp (add nsw x, y)) when x and
      // y have the same type, at least one conversion has a single use (so
      // the number of conversions does not grow), and the add cannot
      // overflow.
      if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) {
        Value *RHSIntVal = RHSConv->getOperand(0);
        if (LHSIntVal->getType() == RHSIntVal->getType() &&
            (LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
            WillNotOverflowSignedAdd(LHSIntVal, RHSIntVal)) {
          Value *NewAdd = Builder->CreateNSWAdd(LHSIntVal, RHSIntVal,
                                                "addconv");
          return new SIToFPInst(NewAdd, I.getType());
        }
      }
    }
  }

  return Changed ? &I : 0;
}
/// EmitGEPOffset - Emit instructions that compute the byte offset a GEP (an
/// instruction or a constant expression) adds to its base pointer, and return
/// that offset as a value of the target's intptr type. Requires target data.
Value *InstCombiner::EmitGEPOffset(User *GEP) {
  TargetData &TD = *getTargetData();
  const Type *IntPtrTy = TD.getIntPtrType(GEP->getContext());
  Value *Result = Constant::getNullValue(IntPtrTy);

  // Element sizes are truncated to the pointer width.
  const unsigned PtrBits = TD.getPointerSizeInBits();
  const uint64_t SizeMask = ~0ULL >> (64 - PtrBits);

  gep_type_iterator GTI = gep_type_begin(GEP);
  for (User::op_iterator OI = GEP->op_begin() + 1, OE = GEP->op_end();
       OI != OE; ++OI, ++GTI) {
    Value *Idx = *OI;
    uint64_t ElemSize = TD.getTypeAllocSize(GTI.getIndexedType()) & SizeMask;
    ConstantInt *CIdx = dyn_cast<ConstantInt>(Idx);

    if (!CIdx) {
      // Variable index: sign-extend or truncate to intptr, scale by the
      // element size (a mul by 1 is skipped), and accumulate.
      if (Idx->getType() != IntPtrTy)
        Idx = Builder->CreateIntCast(Idx, IntPtrTy, true,
                                     Idx->getName()+".c");
      if (ElemSize != 1)
        Idx = Builder->CreateMul(Idx, ConstantInt::get(IntPtrTy, ElemSize),
                                 GEP->getName()+".idx");
      Result = Builder->CreateAdd(Idx, Result, GEP->getName()+".offs");
    } else if (CIdx->isZero()) {
      // A zero constant index adds nothing to the offset.
    } else if (const StructType *STy = dyn_cast<StructType>(*GTI)) {
      // Struct field index: add the field's byte offset from the layout.
      uint64_t FieldOffset =
          TD.getStructLayout(STy)->getElementOffset(CIdx->getZExtValue());
      Result = Builder->CreateAdd(Result,
                                  ConstantInt::get(IntPtrTy, FieldOffset),
                                  GEP->getName()+".offs");
    } else {
      // Constant array/pointer index: fold sext(index) * size as a constant.
      Constant *SizeC = ConstantInt::get(IntPtrTy, ElemSize);
      Constant *IdxC =
          ConstantExpr::getIntegerCast(CIdx, IntPtrTy, true /*SExt*/);
      Constant *Scaled = ConstantExpr::getMul(IdxC, SizeC);
      Result = Builder->CreateAdd(Result, Scaled, GEP->getName()+".offs");
    }
  }
  return Result;
}
/// OptimizePointerDifference - Try to compute "LHS - RHS" for two pointers
/// (as integers of type Ty) from GEP offsets. This applies when one side is a
/// GEP instruction whose base operand is the other side, or when both sides
/// are GEPs off the same base and the non-instruction side is a constant
/// expression GEP. On success, emits the difference, casts it (signed) to Ty
/// and returns it; otherwise returns null. Requires target data.
Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS,
                                               const Type *Ty) {
  assert(TD && "Must have target data info for this");

  GetElementPtrInst *GEP = 0;  // GEP instruction whose offset is emitted.
  ConstantExpr *CstGEP = 0;    // Optional constant-expression GEP opposite it.
  bool Swapped = false;        // True when the GEP instruction is on the RHS.

  // (gep X, ...) - X    or    (gep X, ...) - (ce_gep X, ...)
  if (GetElementPtrInst *LeftGEP = dyn_cast<GetElementPtrInst>(LHS)) {
    Value *Base = LeftGEP->getOperand(0);
    ConstantExpr *CE = dyn_cast<ConstantExpr>(RHS);
    if (Base == RHS) {
      GEP = LeftGEP;
    } else if (CE && CE->getOpcode() == Instruction::GetElementPtr &&
               Base == CE->getOperand(0)) {
      GEP = LeftGEP;
      CstGEP = CE;
    }
  }

  // X - (gep X, ...)    or    (ce_gep X, ...) - (gep X, ...)
  // A match here overrides one found above.
  if (GetElementPtrInst *RightGEP = dyn_cast<GetElementPtrInst>(RHS)) {
    Value *Base = RightGEP->getOperand(0);
    ConstantExpr *CE = dyn_cast<ConstantExpr>(LHS);
    if (Base == LHS) {
      GEP = RightGEP;
      Swapped = true;
    } else if (CE && CE->getOpcode() == Instruction::GetElementPtr &&
               Base == CE->getOperand(0)) {
      GEP = RightGEP;
      CstGEP = CE;
      Swapped = true;
    }
  }

  if (!GEP)
    return 0;

  Value *Offset = EmitGEPOffset(GEP);

  // Subtract the constant GEP's own offset from the same base, if any.
  if (CstGEP) {
    Value *CstOffset = EmitGEPOffset(CstGEP);
    Offset = Builder->CreateSub(Offset, CstOffset);
  }

  // For "p - gep(p, ...)" the computed offset has the opposite sign.
  if (Swapped)
    Offset = Builder->CreateNeg(Offset, "diff.neg");

  return Builder->CreateIntCast(Offset, Ty, true);
}
/// visitSub - Combine an integer (or integer vector) subtract instruction I.
/// Returns a replacement instruction to be inserted in place of I, or null
/// if no fold applied. Note that the "x - (y - z)" fold rewrites the inner
/// sub in place before returning I's replacement.
Instruction *InstCombiner::visitSub(BinaryOperator &I) {
  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);

  // X - X --> 0
  if (Op0 == Op1)
    return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));

  // x - (-A) --> x + A
  // nuw is never kept: "x -nuw (-A)" requires x >=u -A, which forces x + A
  // to wrap for any A != 0. nsw is kept only if negating A cannot overflow
  // (Op1 is "sub nsw 0, A", or a constant other than INT_MIN); otherwise,
  // with A == INT_MIN, "x -nsw INT_MIN" and "x +nsw INT_MIN" hold for
  // disjoint x.
  if (Value *V = dyn_castNegVal(Op1)) {
    BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V);
    bool NegCannotOverflow = false;
    if (BinaryOperator *NegInst = dyn_cast<BinaryOperator>(Op1))
      NegCannotOverflow = NegInst->hasNoSignedWrap();
    else if (ConstantInt *NegC = dyn_cast<ConstantInt>(Op1))
      NegCannotOverflow = !NegC->getValue().isMinSignedValue();
    Res->setHasNoSignedWrap(I.hasNoSignedWrap() && NegCannotOverflow);
    return Res;
  }

  if (isa<UndefValue>(Op0))
    return ReplaceInstUsesWith(I, Op0);    // undef - X -> undef
  if (isa<UndefValue>(Op1))
    return ReplaceInstUsesWith(I, Op1);    // X - undef -> undef
  if (I.getType()->isIntegerTy(1))         // i1 sub is xor
    return BinaryOperator::CreateXor(Op0, Op1);

  if (ConstantInt *C = dyn_cast<ConstantInt>(Op0)) {
    // -1 - A --> ~A
    if (C->isAllOnesValue())
      return BinaryOperator::CreateNot(Op1);

    // C - ~X --> X + (C+1)
    Value *X = 0;
    if (match(Op1, m_Not(m_Value(X))))
      return BinaryOperator::CreateAdd(X, AddOne(C));

    // 0 - (X >>u (bw-1)) --> X >>s (bw-1)
    // 0 - (X >>s (bw-1)) --> X >>u (bw-1)
    if (C->isZero()) {
      if (BinaryOperator *SI = dyn_cast<BinaryOperator>(Op1)) {
        if (SI->getOpcode() == Instruction::LShr) {
          if (ConstantInt *CU = dyn_cast<ConstantInt>(SI->getOperand(1))) {
            if (CU->getLimitedValue(SI->getType()->getPrimitiveSizeInBits()) ==
                SI->getType()->getPrimitiveSizeInBits()-1) {
              return BinaryOperator::Create(Instruction::AShr,
                                            SI->getOperand(0), CU,
                                            SI->getName());
            }
          }
        } else if (SI->getOpcode() == Instruction::AShr) {
          if (ConstantInt *CU = dyn_cast<ConstantInt>(SI->getOperand(1))) {
            if (CU->getLimitedValue(SI->getType()->getPrimitiveSizeInBits()) ==
                SI->getType()->getPrimitiveSizeInBits()-1) {
              return BinaryOperator::CreateLShr(SI->getOperand(0), CU,
                                                SI->getName());
            }
          }
        }
      }
    }

    // Try to fold the constant sub into the operands of a select.
    if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
      if (Instruction *R = FoldOpIntoSelect(I, SI))
        return R;

    // C - zext(i1 b) --> select b, C-1, C
    if (ZExtInst *ZI = dyn_cast<ZExtInst>(Op1))
      if (ZI->getSrcTy() == Type::getInt1Ty(I.getContext()))
        return SelectInst::Create(ZI->getOperand(0), SubOne(C), C);
  }

  if (BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1)) {
    if (Op1I->getOpcode() == Instruction::Add) {
      if (Op1I->getOperand(0) == Op0)           // X - (X + Y) --> -Y
        return BinaryOperator::CreateNeg(Op1I->getOperand(1), I.getName());
      else if (Op1I->getOperand(1) == Op0)      // X - (Y + X) --> -Y
        return BinaryOperator::CreateNeg(Op1I->getOperand(0), I.getName());
      else if (ConstantInt *CI1 = dyn_cast<ConstantInt>(I.getOperand(0))) {
        // C1 - (X + C2) --> (C1 - C2) - X
        if (ConstantInt *CI2 = dyn_cast<ConstantInt>(Op1I->getOperand(1)))
          return BinaryOperator::CreateSub(
            ConstantExpr::getSub(CI1, CI2), Op1I->getOperand(0));
      }
    }

    if (Op1I->hasOneUse()) {
      // x - (y - z) --> x + (z - y), reusing the single-use inner sub by
      // swapping its operands in place. Its nsw/nuw flags described "y - z"
      // and say nothing about "z - y", so they must be dropped.
      if (Op1I->getOpcode() == Instruction::Sub) {
        Value *IIOp0 = Op1I->getOperand(0), *IIOp1 = Op1I->getOperand(1);
        Op1I->setOperand(0, IIOp1);
        Op1I->setOperand(1, IIOp0);
        Op1I->setHasNoSignedWrap(false);
        Op1I->setHasNoUnsignedWrap(false);
        return BinaryOperator::CreateAdd(Op0, Op1);
      }

      // A - (A & B) --> A & ~B
      if (Op1I->getOpcode() == Instruction::And &&
          (Op1I->getOperand(0) == Op0 || Op1I->getOperand(1) == Op0)) {
        Value *OtherOp = Op1I->getOperand(Op1I->getOperand(0) == Op0);
        Value *NewNot = Builder->CreateNot(OtherOp, "B.not");
        return BinaryOperator::CreateAnd(Op0, NewNot);
      }

      // 0 - (X sdiv C) --> X sdiv -C, except when negating C changes the
      // result. For C == INT_MIN, -C == C. For C == 1, "X sdiv -1" is
      // undefined for X == INT_MIN while "0 - X" is well defined.
      if (Op1I->getOpcode() == Instruction::SDiv)
        if (ConstantInt *CSI = dyn_cast<ConstantInt>(Op0))
          if (CSI->isZero())
            if (ConstantInt *DivRHS =
                    dyn_cast<ConstantInt>(Op1I->getOperand(1)))
              if (!DivRHS->getValue().isMinSignedValue() && !DivRHS->isOne())
                return BinaryOperator::CreateSDiv(Op1I->getOperand(0),
                                                ConstantExpr::getNeg(DivRHS));

      // 0 - (C << X) --> (-C) << X
      if (Op1I->getOpcode() == Instruction::Shl)
        if (ConstantInt *CSI = dyn_cast<ConstantInt>(Op0))
          if (CSI->isZero())
            if (Value *ShlLHSNeg = dyn_castNegVal(Op1I->getOperand(0)))
              return BinaryOperator::CreateShl(ShlLHSNeg, Op1I->getOperand(1));

      // X - X*C --> X * (1 - C)
      ConstantInt *C2 = 0;
      if (dyn_castFoldableMul(Op1I, C2) == Op0) {
        Constant *CP1 =
          ConstantExpr::getSub(ConstantInt::get(I.getType(), 1), C2);
        return BinaryOperator::CreateMul(Op0, CP1);
      }
    }
  }

  if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) {
    if (Op0I->getOpcode() == Instruction::Add) {
      if (Op0I->getOperand(0) == Op1)           // (Y + X) - Y --> X
        return ReplaceInstUsesWith(I, Op0I->getOperand(1));
      else if (Op0I->getOperand(1) == Op1)      // (X + Y) - Y --> X
        return ReplaceInstUsesWith(I, Op0I->getOperand(0));
    } else if (Op0I->getOpcode() == Instruction::Sub) {
      if (Op0I->getOperand(0) == Op1)           // (X - Y) - X --> -Y
        return BinaryOperator::CreateNeg(Op0I->getOperand(1), I.getName());
    }
  }

  // X*C - X --> X * (C-1);  X*C1 - X*C2 --> X * (C1-C2)
  ConstantInt *C1;
  if (Value *X = dyn_castFoldableMul(Op0, C1)) {
    if (X == Op1)
      return BinaryOperator::CreateMul(Op1, SubOne(C1));

    ConstantInt *C2;
    if (X == dyn_castFoldableMul(Op1, C2))
      return BinaryOperator::CreateMul(X, ConstantExpr::getSub(C1, C2));
  }

  // Pointer differences within one object, e.g. &A[10] - &A[0], become a
  // computed byte offset; also handled under matching truncs.
  if (TD) {
    Value *LHSOp, *RHSOp;
    if (match(Op0, m_PtrToInt(m_Value(LHSOp))) &&
        match(Op1, m_PtrToInt(m_Value(RHSOp))))
      if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
        return ReplaceInstUsesWith(I, Res);

    if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) &&
        match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp)))))
      if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType()))
        return ReplaceInstUsesWith(I, Res);
  }

  return 0;
}
/// visitFSub - Combine a floating-point subtract instruction I.
/// Rewrites "x - (-A)" as "x + A"; returns null when that pattern is absent.
Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
  Value *NegatedArg = dyn_castFNegVal(I.getOperand(1));
  if (!NegatedArg)
    return 0;
  return BinaryOperator::CreateFAdd(I.getOperand(0), NegatedArg);
}