// InstCombineSimplifyDemanded.cpp
#include "InstCombine.h"
#include "llvm/Target/TargetData.h"
#include "llvm/IntrinsicInst.h"
using namespace llvm;
/// ShrinkDemandedConstant - Inspect operand number OpNo of instruction I.  If
/// it is a ConstantInt with bits set outside of the Demanded mask, replace the
/// operand with a constant that has those extra bits cleared.
///
/// Demanded is zero-extended or truncated to the width of the constant before
/// the comparison.  Returns true if the operand was rewritten, false if the
/// operand is not a ConstantInt or already has no bits outside the mask.
static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
                                   APInt Demanded) {
  assert(I && "No instruction?");
  assert(OpNo < I->getNumOperands() && "Operand index too large");

  // Only constant-integer operands can be shrunk.
  ConstantInt *OpC = dyn_cast<ConstantInt>(I->getOperand(OpNo));
  if (OpC == 0)
    return false;

  const APInt &ConstVal = OpC->getValue();
  Demanded = Demanded.zextOrTrunc(ConstVal.getBitWidth());

  // Bits of the constant that nobody asked for.  If there are none, the
  // constant is already as small as it can get.
  APInt Undemanded = ~Demanded & ConstVal;
  if (Undemanded == 0)
    return false;

  // Keep only the demanded bits of the constant and install the new operand.
  Demanded &= ConstVal;
  I->setOperand(OpNo, ConstantInt::get(OpC->getType(), Demanded));
  return true;
}
/// SimplifyDemandedInstructionBits - Run the demanded-bits simplifier on Inst
/// with every bit of its (scalar) width demanded.
///
/// Returns false when SimplifyDemandedUseBits reports no change.  Otherwise
/// returns true; if the simplifier produced a value other than Inst itself,
/// all uses of Inst are first replaced with that value.
bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) {
  const unsigned Width = Inst.getType()->getScalarSizeInBits();
  APInt KnownZero(Width, 0);
  APInt KnownOne(Width, 0);
  APInt AllBits(APInt::getAllOnesValue(Width));

  Value *Simplified =
    SimplifyDemandedUseBits(&Inst, AllBits, KnownZero, KnownOne, 0);
  if (!Simplified)
    return false;

  // A result equal to Inst means Inst was updated in place; anything else is
  // a replacement value for Inst.
  if (Simplified != &Inst)
    ReplaceInstUsesWith(Inst, Simplified);
  return true;
}
/// SimplifyDemandedBits - Use-based wrapper around SimplifyDemandedUseBits.
/// Simplifies the value held by U given DemandedMask; when the simplifier
/// returns a non-null value, U is updated to point at it and true is returned.
/// KnownZero/KnownOne are filled in by SimplifyDemandedUseBits.
bool InstCombiner::SimplifyDemandedBits(Use &U, APInt DemandedMask,
                                        APInt &KnownZero, APInt &KnownOne,
                                        unsigned Depth) {
  if (Value *Replacement = SimplifyDemandedUseBits(U.get(), DemandedMask,
                                                   KnownZero, KnownOne,
                                                   Depth)) {
    U = Replacement;
    return true;
  }
  return false;
}
/// SimplifyDemandedUseBits - Attempt to simplify V given that only the bits
/// set in DemandedMask are of interest to the caller.
///
/// On return, KnownZero and KnownOne describe the bits of V (within the
/// demanded mask) that were determined to be zero / one.  All three APInts
/// must have the same bit width (asserted below).
///
/// Return value:
///   - 0 if no simplification was made;
///   - the instruction V itself if one of its operands was rewritten in place;
///   - otherwise a different Value (an operand, a constant, or a newly
///     inserted instruction) that can be used in place of V for the demanded
///     bits.
///
/// Depth is the recursion depth; the search stops once it reaches 6.
Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
                                             APInt &KnownZero, APInt &KnownOne,
                                             unsigned Depth) {
  assert(V != 0 && "Null pointer of Value???");
  assert(Depth <= 6 && "Limit Search Depth");
  uint32_t BitWidth = DemandedMask.getBitWidth();
  const Type *VTy = V->getType();
  assert((TD || !VTy->isPointerTy()) &&
         "SimplifyDemandedBits needs to know bit widths!");
  assert((!TD || TD->getTypeSizeInBits(VTy->getScalarType()) == BitWidth) &&
         (!VTy->isIntOrIntVectorTy() ||
          VTy->getScalarSizeInBits() == BitWidth) &&
         KnownZero.getBitWidth() == BitWidth &&
         KnownOne.getBitWidth() == BitWidth &&
         "Value *V, DemandedMask, KnownZero and KnownOne "
         "must have same BitWidth");

  // Every bit of a constant integer is known.
  if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
    KnownOne = CI->getValue() & DemandedMask;
    KnownZero = ~KnownOne & DemandedMask;
    return 0;
  }

  // A null pointer constant is all zeros.
  if (isa<ConstantPointerNull>(V)) {
    KnownOne.clearAllBits();
    KnownZero = DemandedMask;
    return 0;
  }

  KnownZero.clearAllBits();
  KnownOne.clearAllBits();

  // Nothing is demanded from V: it may be replaced by undef (unless it
  // already is undef, in which case there is nothing to do).
  if (DemandedMask == 0) {
    if (isa<UndefValue>(V))
      return 0;
    return UndefValue::get(VTy);
  }

  // Limit search depth.
  if (Depth == 6)
    return 0;

  APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
  APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0);

  // Non-instructions cannot be simplified here; just compute known bits.
  Instruction *I = dyn_cast<Instruction>(V);
  if (!I) {
    ComputeMaskedBits(V, DemandedMask, KnownZero, KnownOne, Depth);
    return 0;
  }

  // A non-root instruction with multiple uses is never modified (the other
  // users may demand different bits).  For and/or we can still look through
  // it and hand back one of its operands, or a constant, for *this* use.
  if (Depth != 0 && !I->hasOneUse()) {
    if (I->getOpcode() == Instruction::And) {
      ComputeMaskedBits(I->getOperand(1), DemandedMask,
                        RHSKnownZero, RHSKnownOne, Depth+1);
      ComputeMaskedBits(I->getOperand(0), DemandedMask & ~RHSKnownZero,
                        LHSKnownZero, LHSKnownOne, Depth+1);

      // If all of the demanded bits are known one on one side, return the
      // other operand.  These bits cannot contribute to the result of the
      // 'and' in this context.
      if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) ==
          (DemandedMask & ~LHSKnownZero))
        return I->getOperand(0);
      if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) ==
          (DemandedMask & ~RHSKnownZero))
        return I->getOperand(1);

      // If every demanded bit is known zero in one operand or the other, the
      // demanded part of the result is zero.
      if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask)
        return Constant::getNullValue(VTy);

    } else if (I->getOpcode() == Instruction::Or) {
      ComputeMaskedBits(I->getOperand(1), DemandedMask,
                        RHSKnownZero, RHSKnownOne, Depth+1);
      ComputeMaskedBits(I->getOperand(0), DemandedMask & ~RHSKnownOne,
                        LHSKnownZero, LHSKnownOne, Depth+1);

      // If all of the demanded bits are known zero on one side, return the
      // other operand.  These bits cannot contribute to the result of the
      // 'or' in this context.
      if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) ==
          (DemandedMask & ~LHSKnownOne))
        return I->getOperand(0);
      if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) ==
          (DemandedMask & ~RHSKnownOne))
        return I->getOperand(1);

      // If all of the potentially set bits on one side are known to be set
      // on the other side, just use the 'other' side.
      if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) ==
          (DemandedMask & (~RHSKnownZero)))
        return I->getOperand(0);
      if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) ==
          (DemandedMask & (~LHSKnownZero)))
        return I->getOperand(1);
    }

    // Compute the KnownZero/KnownOne bits to simplify things downstream.
    ComputeMaskedBits(I, DemandedMask, KnownZero, KnownOne, Depth);
    return 0;
  }

  // If this is the root being simplified and it has multiple uses, we may
  // simplify it, but only under the assumption that all bits are demanded.
  if (Depth == 0 && !V->hasOneUse())
    DemandedMask = APInt::getAllOnesValue(BitWidth);

  switch (I->getOpcode()) {
  default:
    ComputeMaskedBits(I, DemandedMask, KnownZero, KnownOne, Depth);
    break;

  case Instruction::And:
    // If either the LHS or the RHS are zero, the result is zero, so bits
    // known zero in the RHS are not demanded from the LHS.
    if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask,
                             RHSKnownZero, RHSKnownOne, Depth+1) ||
        SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownZero,
                             LHSKnownZero, LHSKnownOne, Depth+1))
      return I;
    assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
    assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");

    // If all of the demanded bits are known one on one side, return the other.
    // These bits cannot contribute to the result of the 'and'.
    if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) ==
        (DemandedMask & ~LHSKnownZero))
      return I->getOperand(0);
    if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) ==
        (DemandedMask & ~RHSKnownZero))
      return I->getOperand(1);

    // If all of the demanded bits in the inputs are known zeros, return zero.
    if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask)
      return Constant::getNullValue(VTy);

    // If the RHS is a constant, see if we can simplify it.
    if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnownZero))
      return I;

    // Output known-1 bits are only known if set in both the LHS & RHS.
    KnownOne = RHSKnownOne & LHSKnownOne;
    // Output known-0 are known to be clear if zero in either the LHS | RHS.
    KnownZero = RHSKnownZero | LHSKnownZero;
    break;

  case Instruction::Or:
    // If either the LHS or the RHS are one, the result is one, so bits known
    // one in the RHS are not demanded from the LHS.
    if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask,
                             RHSKnownZero, RHSKnownOne, Depth+1) ||
        SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownOne,
                             LHSKnownZero, LHSKnownOne, Depth+1))
      return I;
    assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
    assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");

    // If all of the demanded bits are known zero on one side, return the other.
    // These bits cannot contribute to the result of the 'or'.
    if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) ==
        (DemandedMask & ~LHSKnownOne))
      return I->getOperand(0);
    if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) ==
        (DemandedMask & ~RHSKnownOne))
      return I->getOperand(1);

    // If all of the potentially set bits on one side are known to be set on
    // the other side, just use the 'other' side.
    if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) ==
        (DemandedMask & (~RHSKnownZero)))
      return I->getOperand(0);
    if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) ==
        (DemandedMask & (~LHSKnownZero)))
      return I->getOperand(1);

    // If the RHS is a constant, see if we can simplify it.
    if (ShrinkDemandedConstant(I, 1, DemandedMask))
      return I;

    // Output known-0 bits are only known if clear in both the LHS & RHS.
    KnownZero = RHSKnownZero & LHSKnownZero;
    // Output known-1 are known to be set if set in either the LHS | RHS.
    KnownOne = RHSKnownOne | LHSKnownOne;
    break;

  case Instruction::Xor: {
    if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask,
                             RHSKnownZero, RHSKnownOne, Depth+1) ||
        SimplifyDemandedBits(I->getOperandUse(0), DemandedMask,
                             LHSKnownZero, LHSKnownOne, Depth+1))
      return I;
    assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
    assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");

    // If all of the demanded bits are known zero on one side, return the other.
    // These bits cannot contribute to the result of the 'xor'.
    if ((DemandedMask & RHSKnownZero) == DemandedMask)
      return I->getOperand(0);
    if ((DemandedMask & LHSKnownZero) == DemandedMask)
      return I->getOperand(1);

    // If every demanded bit is known zero on one side or the other, the xor
    // behaves like an inclusive or on the demanded bits.
    if ((DemandedMask & ~RHSKnownZero & ~LHSKnownZero) == 0) {
      Instruction *Or =
        BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1),
                                 I->getName());
      return InsertNewInstBefore(Or, *I);
    }

    // If all of the demanded bits of the RHS are known, and every bit set in
    // the RHS is also known to be set in the LHS, the xor just clears those
    // bits: turn it into an AND.
    if ((DemandedMask & (RHSKnownZero|RHSKnownOne)) == DemandedMask) {
      // all known
      if ((RHSKnownOne & LHSKnownOne) == RHSKnownOne) {
        Constant *AndC = Constant::getIntegerValue(VTy,
                                                   ~RHSKnownOne & DemandedMask);
        Instruction *And =
          BinaryOperator::CreateAnd(I->getOperand(0), AndC, "tmp");
        return InsertNewInstBefore(And, *I);
      }
    }

    // If the RHS is a constant, see if we can simplify it.
    if (ShrinkDemandedConstant(I, 1, DemandedMask))
      return I;

    // If the LHS is a single-use 'and' with a constant mask, the xor has a
    // constant RHS, and some demanded bit being flipped is known to be set,
    // then the xor is just resetting those bits to zero.  Knock those bits
    // out of both the 'and' mask and the 'xor' constant.
    if (Instruction *LHSInst = dyn_cast<Instruction>(I->getOperand(0)))
      if (LHSInst->getOpcode() == Instruction::And && LHSInst->hasOneUse() &&
          isa<ConstantInt>(I->getOperand(1)) &&
          isa<ConstantInt>(LHSInst->getOperand(1)) &&
          (LHSKnownOne & RHSKnownOne & DemandedMask) != 0) {
        ConstantInt *AndRHS = cast<ConstantInt>(LHSInst->getOperand(1));
        ConstantInt *XorRHS = cast<ConstantInt>(I->getOperand(1));
        APInt NewMask = ~(LHSKnownOne & RHSKnownOne & DemandedMask);

        Constant *AndC =
          ConstantInt::get(I->getType(), NewMask & AndRHS->getValue());
        Instruction *NewAnd =
          BinaryOperator::CreateAnd(I->getOperand(0), AndC, "tmp");
        InsertNewInstBefore(NewAnd, *I);

        Constant *XorC =
          ConstantInt::get(I->getType(), NewMask & XorRHS->getValue());
        Instruction *NewXor =
          BinaryOperator::CreateXor(NewAnd, XorC, "tmp");
        return InsertNewInstBefore(NewXor, *I);
      }

    // Output known-0 bits are known if clear or set in both the LHS & RHS.
    KnownZero= (RHSKnownZero & LHSKnownZero) | (RHSKnownOne & LHSKnownOne);
    // Output known-1 are known to be set if set in only one of the LHS, RHS.
    KnownOne = (RHSKnownZero & LHSKnownOne) | (RHSKnownOne & LHSKnownZero);
    break;
  }

  case Instruction::Select:
    // Operand 2 is the false value, operand 1 the true value.
    if (SimplifyDemandedBits(I->getOperandUse(2), DemandedMask,
                             RHSKnownZero, RHSKnownOne, Depth+1) ||
        SimplifyDemandedBits(I->getOperandUse(1), DemandedMask,
                             LHSKnownZero, LHSKnownOne, Depth+1))
      return I;
    assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
    assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");

    // If the operands are constants, see if we can simplify them.
    if (ShrinkDemandedConstant(I, 1, DemandedMask) ||
        ShrinkDemandedConstant(I, 2, DemandedMask))
      return I;

    // Only known if known in both arms.
    KnownOne = RHSKnownOne & LHSKnownOne;
    KnownZero = RHSKnownZero & LHSKnownZero;
    break;

  case Instruction::Trunc: {
    // Widen the masks to the source width, simplify the source, then narrow
    // the results back to this instruction's width.
    unsigned truncBf = I->getOperand(0)->getType()->getScalarSizeInBits();
    DemandedMask = DemandedMask.zext(truncBf);
    KnownZero = KnownZero.zext(truncBf);
    KnownOne = KnownOne.zext(truncBf);
    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask,
                             KnownZero, KnownOne, Depth+1))
      return I;
    DemandedMask = DemandedMask.trunc(BitWidth);
    KnownZero = KnownZero.trunc(BitWidth);
    KnownOne = KnownOne.trunc(BitWidth);
    assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
    break;
  }

  case Instruction::BitCast:
    // Only integer (or integer vector) sources are handled.
    if (!I->getOperand(0)->getType()->isIntOrIntVectorTy())
      return 0;

    if (const VectorType *DstVTy = dyn_cast<VectorType>(I->getType())) {
      if (const VectorType *SrcVTy =
            dyn_cast<VectorType>(I->getOperand(0)->getType())) {
        // Don't touch a bitcast between vectors of different element counts.
        if (DstVTy->getNumElements() != SrcVTy->getNumElements())
          return 0;
      } else
        // Don't touch a scalar-to-vector bitcast.
        return 0;
    } else if (I->getOperand(0)->getType()->isVectorTy())
      // Don't touch a vector-to-scalar bitcast.
      return 0;

    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask,
                             KnownZero, KnownOne, Depth+1))
      return I;
    assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
    break;

  case Instruction::ZExt: {
    // Compute the bits in the result that are not present in the input.
    unsigned SrcBitWidth =I->getOperand(0)->getType()->getScalarSizeInBits();

    DemandedMask = DemandedMask.trunc(SrcBitWidth);
    KnownZero = KnownZero.trunc(SrcBitWidth);
    KnownOne = KnownOne.trunc(SrcBitWidth);
    if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask,
                             KnownZero, KnownOne, Depth+1))
      return I;
    DemandedMask = DemandedMask.zext(BitWidth);
    KnownZero = KnownZero.zext(BitWidth);
    KnownOne = KnownOne.zext(BitWidth);
    assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
    // The top bits are known to be zero.
    KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth);
    break;
  }

  case Instruction::SExt: {
    // Compute the bits in the result that are not present in the input.
    unsigned SrcBitWidth =I->getOperand(0)->getType()->getScalarSizeInBits();

    APInt InputDemandedBits = DemandedMask &
                              APInt::getLowBitsSet(BitWidth, SrcBitWidth);

    APInt NewBits(APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth));
    // If any of the sign-extended bits are demanded, the input sign bit is
    // demanded.
    if ((NewBits & DemandedMask) != 0)
      InputDemandedBits.setBit(SrcBitWidth-1);

    InputDemandedBits = InputDemandedBits.trunc(SrcBitWidth);
    KnownZero = KnownZero.trunc(SrcBitWidth);
    KnownOne = KnownOne.trunc(SrcBitWidth);
    if (SimplifyDemandedBits(I->getOperandUse(0), InputDemandedBits,
                             KnownZero, KnownOne, Depth+1))
      return I;
    InputDemandedBits = InputDemandedBits.zext(BitWidth);
    KnownZero = KnownZero.zext(BitWidth);
    KnownOne = KnownOne.zext(BitWidth);
    assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");

    // If the input sign bit is known zero, or none of the top bits are
    // demanded, convert this into a zero extension.
    if (KnownZero[SrcBitWidth-1] || (NewBits & ~DemandedMask) == NewBits) {
      CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy, I->getName());
      return InsertNewInstBefore(NewCast, *I);
    } else if (KnownOne[SrcBitWidth-1]) {
      // Input sign bit known set: all of the new high bits are ones.
      KnownOne |= NewBits;
    }
    break;
  }

  case Instruction::Add: {
    // Figure out what the input bits are.  If the top bits of the output are
    // not demanded, then the add doesn't demand them from its inputs either.
    unsigned NLZ = DemandedMask.countLeadingZeros();

    if (ConstantInt *RHS = dyn_cast<ConstantInt>(I->getOperand(1))) {
      // If there is a constant on the RHS, there are a variety of xformations
      // we can do.
      if (RHS->isZero())
        break;

      APInt InDemandedBits(APInt::getLowBitsSet(BitWidth, BitWidth - NLZ));

      // Find information about known zero/one bits in the input.
      if (SimplifyDemandedBits(I->getOperandUse(0), InDemandedBits,
                               LHSKnownZero, LHSKnownOne, Depth+1))
        return I;

      // If the RHS of the add has bits set that can't affect the input,
      // reduce the constant.
      if (ShrinkDemandedConstant(I, 1, InDemandedBits))
        return I;

      // Avoid excess work when nothing is known about the LHS.
      if (LHSKnownZero == 0 && LHSKnownOne == 0)
        break;

      // Turn it into OR if the input bits are known zero wherever the
      // constant has a bit set.
      if ((LHSKnownZero & RHS->getValue()) == RHS->getValue()) {
        Instruction *Or =
          BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1),
                                   I->getName());
        return InsertNewInstBefore(Or, *I);
      }

      // Bits are known one if they are known zero in one operand and one in
      // the other, and there is no input carry.
      const APInt &RHSVal = RHS->getValue();
      APInt CarryBits((~LHSKnownZero + RHSVal) ^ (~LHSKnownZero ^ RHSVal));

      KnownOne = ((LHSKnownZero & RHSVal) |
                  (LHSKnownOne & ~RHSVal)) & ~CarryBits;

      // Bits are known zero if they are known zero in both operands and there
      // is no input carry.
      KnownZero = LHSKnownZero & ~RHSVal & ~CarryBits;
    } else {
      // If the high bits of this add are not demanded, then it does not
      // demand the high bits of its LHS or RHS.  The known bits computed for
      // the operands are not used here.
      if (DemandedMask[BitWidth-1] == 0) {
        // Right fill the mask of bits for this add to demand the most
        // significant bit and all those below it.
        APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ));
        if (SimplifyDemandedBits(I->getOperandUse(0), DemandedFromOps,
                                 LHSKnownZero, LHSKnownOne, Depth+1) ||
            SimplifyDemandedBits(I->getOperandUse(1), DemandedFromOps,
                                 LHSKnownZero, LHSKnownOne, Depth+1))
          return I;
      }
    }
    break;
  }

  case Instruction::Sub:
    // If the high bits of this sub are not demanded, then it does not demand
    // the high bits of its LHS or RHS.
    if (DemandedMask[BitWidth-1] == 0) {
      // Right fill the mask of bits for this sub to demand the most
      // significant bit and all those below it.
      uint32_t NLZ = DemandedMask.countLeadingZeros();
      APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ));
      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedFromOps,
                               LHSKnownZero, LHSKnownOne, Depth+1) ||
          SimplifyDemandedBits(I->getOperandUse(1), DemandedFromOps,
                               LHSKnownZero, LHSKnownOne, Depth+1))
        return I;
    }
    // Otherwise just hand the sub off to ComputeMaskedBits to fill in
    // the known zeros and ones.
    ComputeMaskedBits(V, DemandedMask, KnownZero, KnownOne, Depth);
    break;

  case Instruction::Shl:
    if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
      APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt));

      // If the shift is NUW/NSW, then it does demand the high bits.
      ShlOperator *IOp = cast<ShlOperator>(I);
      if (IOp->hasNoSignedWrap())
        DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
      else if (IOp->hasNoUnsignedWrap())
        DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt);

      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn,
                               KnownZero, KnownOne, Depth+1))
        return I;
      assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
      KnownZero <<= ShiftAmt;
      KnownOne <<= ShiftAmt;
      // Low bits known zero.
      if (ShiftAmt)
        KnownZero |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
    }
    break;

  case Instruction::LShr:
    // For a logical shift right by a constant.
    if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);

      // Unsigned shift right.
      APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));

      // If the shift is exact, then it does demand the low bits (and knows
      // that they are zero).
      if (cast<LShrOperator>(I)->isExact())
        DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt);

      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn,
                               KnownZero, KnownOne, Depth+1))
        return I;
      assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
      KnownZero = APIntOps::lshr(KnownZero, ShiftAmt);
      KnownOne = APIntOps::lshr(KnownOne, ShiftAmt);
      if (ShiftAmt) {
        // Compute the new bits that are at the top now.
        APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt));
        KnownZero |= HighBits;  // high bits known zero.
      }
    }
    break;

  case Instruction::AShr:
    // If only the low bit is demanded, rewrite the arithmetic shift as a
    // logical shift right, even if the shift amount is variable.
    if (DemandedMask == 1) {
      // Perform the logical shift right.
      Instruction *NewVal = BinaryOperator::CreateLShr(
                        I->getOperand(0), I->getOperand(1), I->getName());
      return InsertNewInstBefore(NewVal, *I);
    }

    // If only the sign bit is demanded, the shift is a noop: shifting in a
    // copy of the sign bit leaves the sign bit unchanged.
    if (DemandedMask.isSignBit())
      return I->getOperand(0);

    if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
      uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1);

      // Signed shift right.
      APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
      // If any of the "high bits" are demanded, we should set the sign bit as
      // demanded.
      if (DemandedMask.countLeadingZeros() <= ShiftAmt)
        DemandedMaskIn.setBit(BitWidth-1);

      // If the shift is exact, then it does demand the low bits (and knows
      // that they are zero).
      if (cast<AShrOperator>(I)->isExact())
        DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt);

      if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn,
                               KnownZero, KnownOne, Depth+1))
        return I;
      assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
      // Compute the new bits that are at the top now.
      APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt));
      KnownZero = APIntOps::lshr(KnownZero, ShiftAmt);
      KnownOne = APIntOps::lshr(KnownOne, ShiftAmt);

      // Handle the sign bits.
      APInt SignBit(APInt::getSignBit(BitWidth));
      // Adjust to where it is now in the mask.
      SignBit = APIntOps::lshr(SignBit, ShiftAmt);

      // If the input sign bit is known to be zero, or if none of the top bits
      // are demanded, turn this into an unsigned shift right.
      if (BitWidth <= ShiftAmt || KnownZero[BitWidth-ShiftAmt-1] ||
          (HighBits & ~DemandedMask) == HighBits) {
        // Perform the logical shift right.
        Instruction *NewVal = BinaryOperator::CreateLShr(
                          I->getOperand(0), SA, I->getName());
        return InsertNewInstBefore(NewVal, *I);
      } else if ((KnownOne & SignBit) != 0) { // New bits are known one.
        KnownOne |= HighBits;
      }
    }
    break;

  case Instruction::SRem:
    if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
      // A divisor of -1 is left alone entirely.
      if (Rem->isAllOnesValue())
        break;
      APInt RA = Rem->getValue().abs();
      if (RA.isPowerOf2()) {
        // Only bits below the power of two are demanded: return the LHS.
        if (DemandedMask.ult(RA))
          return I->getOperand(0);

        APInt LowBits = RA - 1;
        APInt Mask2 = LowBits | APInt::getSignBit(BitWidth);
        if (SimplifyDemandedBits(I->getOperandUse(0), Mask2,
                                 LHSKnownZero, LHSKnownOne, Depth+1))
          return I;

        // The low bits of LHS are unchanged by the srem.
        KnownZero = LHSKnownZero & LowBits;
        KnownOne = LHSKnownOne & LowBits;

        // If LHS is non-negative or has all low bits zero, then the upper bits
        // are all zero.
        if (LHSKnownZero[BitWidth-1] || ((LHSKnownZero & LowBits) == LowBits))
          KnownZero |= ~LowBits;

        // If LHS is negative and not all low bits are zero, then the upper bits
        // are all one.
        if (LHSKnownOne[BitWidth-1] && ((LHSKnownOne & LowBits) != 0))
          KnownOne |= ~LowBits;

        assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
      }
    }

    // If the sign bit is demanded and not yet known zero, query the sign bit
    // of the LHS: when it is known zero, the result's sign bit is zero too.
    if (DemandedMask.isNegative() && KnownZero.isNonNegative()) {
      APInt Mask2 = APInt::getSignBit(BitWidth);
      APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
      ComputeMaskedBits(I->getOperand(0), Mask2, LHSKnownZero, LHSKnownOne,
                        Depth+1);
      // If it's known zero, our sign bit is also zero.
      if (LHSKnownZero.isNegative())
        KnownZero |= LHSKnownZero;
    }
    break;

  case Instruction::URem: {
    // Track the known bits of each operand separately.  Sharing one pair of
    // APInts for both calls would let the second call overwrite what was
    // learned about the first operand.
    APInt AllOnes = APInt::getAllOnesValue(BitWidth);
    if (SimplifyDemandedBits(I->getOperandUse(0), AllOnes,
                             LHSKnownZero, LHSKnownOne, Depth+1) ||
        SimplifyDemandedBits(I->getOperandUse(1), AllOnes,
                             RHSKnownZero, RHSKnownOne, Depth+1))
      return I;

    // The unsigned remainder is no larger than either operand, so any leading
    // zero bits known in either operand must also exist in the result.
    unsigned Leaders = std::max(LHSKnownZero.countLeadingOnes(),
                                RHSKnownZero.countLeadingOnes());
    KnownZero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask;
    break;
  }

  case Instruction::Call:
    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
      switch (II->getIntrinsicID()) {
      default: break;
      case Intrinsic::bswap: {
        // If the only bits demanded come from one byte of the bswap result,
        // just shift the input byte into position to eliminate the bswap.
        unsigned NLZ = DemandedMask.countLeadingZeros();
        unsigned NTZ = DemandedMask.countTrailingZeros();

        // Round NTZ down to the next byte.  If we have 11 trailing zeros, then
        // we need all the bits down to bit 8.  Likewise, round NLZ.  If we
        // have 14 leading zeros, round to 8.
        NLZ &= ~7;
        NTZ &= ~7;
        // If we need exactly one byte, we can do this transformation.
        if (BitWidth-NLZ-NTZ == 8) {
          unsigned ResultBit = NTZ;
          unsigned InputBit = BitWidth-NTZ-8;

          // Replace this with either a left or right shift to get the byte
          // into the right place.
          Instruction *NewVal;
          if (InputBit > ResultBit)
            NewVal = BinaryOperator::CreateLShr(II->getArgOperand(0),
                    ConstantInt::get(I->getType(), InputBit-ResultBit));
          else
            NewVal = BinaryOperator::CreateShl(II->getArgOperand(0),
                    ConstantInt::get(I->getType(), ResultBit-InputBit));
          NewVal->takeName(I);
          return InsertNewInstBefore(NewVal, *I);
        }

        break;
      }
      }
    }
    ComputeMaskedBits(V, DemandedMask, KnownZero, KnownOne, Depth);
    break;
  }

  // If the client is only demanding bits that we know, return the known
  // constant.
  if ((DemandedMask & (KnownZero|KnownOne)) == DemandedMask)
    return Constant::getIntegerValue(VTy, KnownOne);
  return 0;
}
/// SimplifyDemandedVectorElts - Attempt to simplify the vector value V given
/// that only the elements whose bits are set in DemandedElts are used.
///
/// DemandedElts has one bit per element of V.  On return, UndefElts has a bit
/// set for each element this routine determined to be undefined.
///
/// Returns 0 if nothing changed; returns V's instruction if one of its
/// operands was updated in place; otherwise returns a replacement value.
/// Depth is the recursion depth; the search stops once it reaches 10.
Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
                                                APInt &UndefElts,
                                                unsigned Depth) {
  unsigned VWidth = cast<VectorType>(V->getType())->getNumElements();
  APInt EltMask(APInt::getAllOnesValue(VWidth));
  assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");

  // Undef: every element is undefined and there is nothing to simplify.
  if (isa<UndefValue>(V)) {
    UndefElts = EltMask;
    return 0;
  }

  // No element is demanded: the whole vector can become undef.
  if (DemandedElts == 0) {
    UndefElts = EltMask;
    return UndefValue::get(V->getType());
  }

  UndefElts = 0;

  // Constant vector: replace every element that is not demanded with undef,
  // and record elements that are (or become) undef.
  if (ConstantVector *CV = dyn_cast<ConstantVector>(V)) {
    const Type *EltTy = cast<VectorType>(V->getType())->getElementType();
    Constant *Undef = UndefValue::get(EltTy);

    std::vector<Constant*> Elts;
    for (unsigned i = 0; i != VWidth; ++i) {
      Constant *Elt = CV->getOperand(i);
      if (!DemandedElts[i] || isa<UndefValue>(Elt)) {
        Elts.push_back(Undef);
        UndefElts.setBit(i);
      } else {
        Elts.push_back(Elt);
      }
    }

    // Only report a change if the rebuilt constant differs from the original.
    Constant *NewCP = ConstantVector::get(Elts);
    return NewCP != CV ? NewCP : 0;
  }

  // Zero vector: keep zeros in demanded lanes, undef elsewhere.
  if (isa<ConstantAggregateZero>(V)) {
    // With every lane demanded there is nothing to simplify.
    if (DemandedElts.isAllOnesValue())
      return 0;

    const Type *EltTy = cast<VectorType>(V->getType())->getElementType();
    Constant *Zero = Constant::getNullValue(EltTy);
    Constant *Undef = UndefValue::get(EltTy);
    std::vector<Constant*> Elts;
    for (unsigned i = 0; i != VWidth; ++i)
      Elts.push_back(DemandedElts[i] ? Zero : Undef);
    UndefElts = DemandedElts ^ EltMask;
    return ConstantVector::get(Elts);
  }

  // Limit search depth.
  if (Depth == 10)
    return 0;

  // A value with multiple uses is only processed at the root, and then only
  // under the assumption that all of its elements are demanded.
  if (!V->hasOneUse()) {
    if (Depth != 0)
      return 0;
    DemandedElts = EltMask;
  }

  Instruction *I = dyn_cast<Instruction>(V);
  if (!I)
    return 0;        // Only analyze instructions.

  bool MadeChange = false;
  APInt UndefElts2(VWidth, 0);
  Value *NewOp;
  switch (I->getOpcode()) {
  default: break;

  case Instruction::InsertElement: {
    ConstantInt *Idx = dyn_cast<ConstantInt>(I->getOperand(2));
    if (Idx == 0) {
      // Variable index: any lane may be written, so just simplify the input
      // vector with the same demanded set.
      NewOp = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts,
                                         UndefElts2, Depth+1);
      if (NewOp) {
        I->setOperand(0, NewOp);
        MadeChange = true;
      }
      break;
    }

    // The inserted lane is out of range or not demanded: the insert can be
    // bypassed by returning the input vector.
    unsigned IdxNo = Idx->getZExtValue();
    if (IdxNo >= VWidth || !DemandedElts[IdxNo]) {
      Worklist.Add(I);
      return I->getOperand(0);
    }

    // Otherwise the inserted lane is not demanded from the input vector.
    APInt DemandedElts2 = DemandedElts;
    DemandedElts2.clearBit(IdxNo);
    NewOp = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts2,
                                       UndefElts, Depth+1);
    if (NewOp) {
      I->setOperand(0, NewOp);
      MadeChange = true;
    }

    // The inserted element is defined.
    UndefElts.clearBit(IdxNo);
    break;
  }

  case Instruction::ShuffleVector: {
    ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
    uint64_t LHSVWidth =
      cast<VectorType>(Shuffle->getOperand(0)->getType())->getNumElements();

    // Translate the demanded result lanes into demanded lanes of each input.
    APInt LeftDemanded(LHSVWidth, 0), RightDemanded(LHSVWidth, 0);
    for (unsigned i = 0; i < VWidth; i++) {
      if (!DemandedElts[i])
        continue;
      unsigned MaskVal = Shuffle->getMaskValue(i);
      if (MaskVal == -1u)
        continue;
      assert(MaskVal < LHSVWidth * 2 &&
             "shufflevector mask index out of range!");
      if (MaskVal < LHSVWidth)
        LeftDemanded.setBit(MaskVal);
      else
        RightDemanded.setBit(MaskVal - LHSVWidth);
    }

    APInt LHSUndefElts(LHSVWidth, 0);
    NewOp = SimplifyDemandedVectorElts(I->getOperand(0), LeftDemanded,
                                       LHSUndefElts, Depth+1);
    if (NewOp) {
      I->setOperand(0, NewOp);
      MadeChange = true;
    }

    APInt RHSUndefElts(LHSVWidth, 0);
    NewOp = SimplifyDemandedVectorElts(I->getOperand(1), RightDemanded,
                                       RHSUndefElts, Depth+1);
    if (NewOp) {
      I->setOperand(1, NewOp);
      MadeChange = true;
    }

    // A result lane is undef if its mask entry is undef or it selects an
    // undef input lane.  Track whether the latter introduced new undefs.
    bool NewUndefElts = false;
    for (unsigned i = 0; i < VWidth; i++) {
      unsigned MaskVal = Shuffle->getMaskValue(i);
      if (MaskVal == -1u) {
        UndefElts.setBit(i);
      } else if (MaskVal < LHSVWidth) {
        if (LHSUndefElts[MaskVal]) {
          NewUndefElts = true;
          UndefElts.setBit(i);
        }
      } else {
        if (RHSUndefElts[MaskVal - LHSVWidth]) {
          NewUndefElts = true;
          UndefElts.setBit(i);
        }
      }
    }

    if (NewUndefElts) {
      // Rebuild the shuffle mask with undef entries for the undef lanes.
      std::vector<Constant*> Elts;
      for (unsigned i = 0; i < VWidth; ++i) {
        if (UndefElts[i])
          Elts.push_back(UndefValue::get(Type::getInt32Ty(I->getContext())));
        else
          Elts.push_back(ConstantInt::get(Type::getInt32Ty(I->getContext()),
                                          Shuffle->getMaskValue(i)));
      }
      I->setOperand(2, ConstantVector::get(Elts));
      MadeChange = true;
    }
    break;
  }

  case Instruction::BitCast: {
    // Only vector-to-vector bitcasts are handled.
    const VectorType *VTy = dyn_cast<VectorType>(I->getOperand(0)->getType());
    if (!VTy)
      break;
    unsigned InVWidth = VTy->getNumElements();
    APInt InputDemandedElts(InVWidth, 0);
    unsigned Ratio;

    if (VWidth == InVWidth) {
      // Same element count: demanded elements map one-to-one.
      Ratio = 1;
      InputDemandedElts = DemandedElts;
    } else if (VWidth > InVWidth) {
      // Not handled: bail out.  The code below the break is unreachable.
      break;

      Ratio = VWidth/InVWidth;
      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) {
        if (DemandedElts[OutIdx])
          InputDemandedElts.setBit(OutIdx/Ratio);
      }
    } else {
      // Not handled: bail out.  The code below the break is unreachable.
      break;

      Ratio = InVWidth/VWidth;
      for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx)
        if (DemandedElts[InIdx/Ratio])
          InputDemandedElts.setBit(InIdx);
    }

    // Simplify the source vector with the translated demanded set.
    NewOp = SimplifyDemandedVectorElts(I->getOperand(0), InputDemandedElts,
                                       UndefElts2, Depth+1);
    if (NewOp) {
      I->setOperand(0, NewOp);
      MadeChange = true;
    }

    UndefElts = UndefElts2;
    if (VWidth > InVWidth) {
      llvm_unreachable("Unimp");
      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
        if (UndefElts2[OutIdx/Ratio])
          UndefElts.setBit(OutIdx);
    } else if (VWidth < InVWidth) {
      llvm_unreachable("Unimp");
      UndefElts = ~0ULL >> (64-VWidth);
      for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx)
        if (!UndefElts2[InIdx])
          UndefElts.clearBit(InIdx/Ratio);
    }
    break;
  }

  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
    // Element-wise binary operators: both operands are simplified with the
    // same demanded set.
    NewOp = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts,
                                       UndefElts, Depth+1);
    if (NewOp) {
      I->setOperand(0, NewOp);
      MadeChange = true;
    }
    NewOp = SimplifyDemandedVectorElts(I->getOperand(1), DemandedElts,
                                       UndefElts2, Depth+1);
    if (NewOp) {
      I->setOperand(1, NewOp);
      MadeChange = true;
    }

    // A result lane is reported undef only if it is undef in both operands.
    UndefElts &= UndefElts2;
    break;

  case Instruction::Call: {
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
    if (!II)
      break;
    switch (II->getIntrinsicID()) {
    default: break;

    // Binary scalar-as-vector operations.
    case Intrinsic::x86_sse_sub_ss:
    case Intrinsic::x86_sse_mul_ss:
    case Intrinsic::x86_sse_min_ss:
    case Intrinsic::x86_sse_max_ss:
    case Intrinsic::x86_sse2_sub_sd:
    case Intrinsic::x86_sse2_mul_sd:
    case Intrinsic::x86_sse2_min_sd:
    case Intrinsic::x86_sse2_max_sd:
      NewOp = SimplifyDemandedVectorElts(II->getArgOperand(0), DemandedElts,
                                         UndefElts, Depth+1);
      if (NewOp) {
        II->setArgOperand(0, NewOp);
        MadeChange = true;
      }
      NewOp = SimplifyDemandedVectorElts(II->getArgOperand(1), DemandedElts,
                                         UndefElts2, Depth+1);
      if (NewOp) {
        II->setArgOperand(1, NewOp);
        MadeChange = true;
      }

      // If only the low element is demanded and this is a sub or mul, lower
      // it to a scalar FSub/FMul on the extracted low elements and insert the
      // result into lane 0 of an undef vector.
      if (DemandedElts == 1) {
        switch (II->getIntrinsicID()) {
        default: break;
        case Intrinsic::x86_sse_sub_ss:
        case Intrinsic::x86_sse_mul_ss:
        case Intrinsic::x86_sse2_sub_sd:
        case Intrinsic::x86_sse2_mul_sd: {
          Value *LHS = II->getArgOperand(0);
          Value *RHS = II->getArgOperand(1);
          // Extract element 0 of each operand.
          LHS = InsertNewInstBefore(ExtractElementInst::Create(LHS,
            ConstantInt::get(Type::getInt32Ty(I->getContext()), 0U)), *II);
          RHS = InsertNewInstBefore(ExtractElementInst::Create(RHS,
            ConstantInt::get(Type::getInt32Ty(I->getContext()), 0U)), *II);

          switch (II->getIntrinsicID()) {
          default: llvm_unreachable("Case stmts out of sync!");
          case Intrinsic::x86_sse_sub_ss:
          case Intrinsic::x86_sse2_sub_sd:
            NewOp = InsertNewInstBefore(BinaryOperator::CreateFSub(LHS, RHS,
                                                        II->getName()), *II);
            break;
          case Intrinsic::x86_sse_mul_ss:
          case Intrinsic::x86_sse2_mul_sd:
            NewOp = InsertNewInstBefore(BinaryOperator::CreateFMul(LHS, RHS,
                                                        II->getName()), *II);
            break;
          }

          Instruction *New =
            InsertElementInst::Create(
              UndefValue::get(II->getType()), NewOp,
              ConstantInt::get(Type::getInt32Ty(I->getContext()), 0U, false),
                                      II->getName());
          InsertNewInstBefore(New, *II);
          return New;
        }
        }
      }

      // A result lane is reported undef only if it is undef in both operands.
      UndefElts &= UndefElts2;
      break;
    }
    break;
  }
  }
  return MadeChange ? I : 0;
}