// ConstantFolding.cpp
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Function.h"
#include "llvm/Instructions.h"
#include "llvm/Intrinsics.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Support/GetElementPtrTypeIterator.h"
#include "llvm/Support/MathExtras.h"
#include <cerrno>
#include <cmath>
using namespace llvm;
/// IsConstantOffsetFromGlobal - If C is a GlobalValue, or a chain of
/// ptrtoint / bitcast / getelementptr (with all-constant indices) rooted at a
/// GlobalValue, set GV to that global, set Offset to the accumulated offset
/// computed with TD's layout information, and return true.  Otherwise return
/// false; GV and Offset may have been partially written in that case.
static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV,
                                       int64_t &Offset, const TargetData &TD) {
  // Base case: the global itself sits at offset zero.
  GV = dyn_cast<GlobalValue>(C);
  if (GV) {
    Offset = 0;
    return true;
  }

  ConstantExpr *CE = dyn_cast<ConstantExpr>(C);
  if (!CE)
    return false;

  switch (CE->getOpcode()) {
  default:
    return false;
  case Instruction::PtrToInt:
  case Instruction::BitCast:
    // These casts don't change the address; look straight through them.
    return IsConstantOffsetFromGlobal(CE->getOperand(0), GV, Offset, TD);
  case Instruction::GetElementPtr:
    break;
  }

  // A GEP over an unsized pointee has no computable offset.
  Constant *Base = CE->getOperand(0);
  if (!cast<PointerType>(Base->getType())->getElementType()->isSized())
    return false;
  if (!IsConstantOffsetFromGlobal(Base, GV, Offset, TD))
    return false;

  // Add each index's contribution on top of the base pointer's offset.
  gep_type_iterator GTI = gep_type_begin(CE);
  for (User::const_op_iterator Idx = CE->op_begin() + 1, IdxEnd = CE->op_end();
       Idx != IdxEnd; ++Idx, ++GTI) {
    ConstantInt *CI = dyn_cast<ConstantInt>(*Idx);
    if (!CI)
      return false;            // Variable index: offset unknown.
    if (CI->getZExtValue() == 0)
      continue;                // Zero index contributes nothing.

    if (const StructType *ST = dyn_cast<StructType>(*GTI)) {
      // Struct field: use the field's layout offset.
      Offset += TD.getStructLayout(ST)->getElementOffset(CI->getZExtValue());
      continue;
    }
    // Array/pointer/vector step: signed index times padded element size.
    const SequentialType *SQT = cast<SequentialType>(*GTI);
    Offset += TD.getTypePaddedSize(SQT->getElementType()) * CI->getSExtValue();
  }
  return true;
}
/// SymbolicallyEvaluateBinop - Fold a "sub" whose operands are both constant
/// offsets from the same global (as recognized by IsConstantOffsetFromGlobal)
/// into the integer difference of those offsets, typed like Op0.  Returns 0
/// for any other opcode, when TD is null, or when the bases differ or can't
/// be determined.
static Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0,
                                           Constant *Op1, const TargetData *TD){
  // Only subtraction benefits from this, and it needs layout information.
  if (Opc != Instruction::Sub || !TD)
    return 0;

  GlobalValue *LHSBase, *RHSBase;
  int64_t LHSOffset, RHSOffset;
  bool SameGlobal =
      IsConstantOffsetFromGlobal(Op0, LHSBase, LHSOffset, *TD) &&
      IsConstantOffsetFromGlobal(Op1, RHSBase, RHSOffset, *TD) &&
      LHSBase == RHSBase;
  if (!SameGlobal)
    return 0;

  return ConstantInt::get(Op0->getType(), LHSOffset - RHSOffset);
}
/// SymbolicallyEvaluateGEP - Fold a getelementptr whose base is a known
/// integer address (null, or inttoptr of a nonzero ConstantInt) and whose
/// indices are all ConstantInts into "inttoptr (base + offset) to ResultTy",
/// computing the offset with TD.  Ops[0] is the base pointer and
/// Ops[1..NumOps-1] are the indices.  Returns 0 if TD is null, the pointee
/// is unsized, or the base/indices don't meet the requirements above.
static Constant *SymbolicallyEvaluateGEP(Constant* const* Ops, unsigned NumOps,
                                         const Type *ResultTy,
                                         const TargetData *TD) {
  Constant *Ptr = Ops[0];
  if (!TD || !cast<PointerType>(Ptr->getType())->getElementType()->isSized())
    return 0;

  // Determine the integer value of the base address.  A non-null base must be
  // inttoptr of a ConstantInt; a zero value there is not folded.
  uint64_t BasePtr = 0;
  if (!Ptr->isNullValue()) {
    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr))
      if (CE->getOpcode() == Instruction::IntToPtr)
        if (ConstantInt *Base = dyn_cast<ConstantInt>(CE->getOperand(0)))
          BasePtr = Base->getZExtValue();
    if (BasePtr == 0)
      return 0;
  }

  // Every index must be a constant integer for the offset to be computable.
  // (This used to be "return false;", which is not a valid null pointer
  // constant for a Constant* return in C++11 and later.)
  for (unsigned i = 1; i != NumOps; ++i)
    if (!isa<ConstantInt>(Ops[i]))
      return 0;

  uint64_t Offset = TD->getIndexedOffset(Ptr->getType(),
                                         (Value**)Ops+1, NumOps-1);
  Constant *C = ConstantInt::get(TD->getIntPtrType(), Offset+BasePtr);
  return ConstantExpr::getIntToPtr(C, ResultTy);
}
/// FoldBitCast - Constant fold a bitcast of the ConstantVector C to the vector
/// type DestTy when the element count changes (e.g. <4 x i32> -> <2 x i64>).
/// TD's endianness decides how narrow elements are packed into, or unpacked
/// from, wide ones.  Floating point elements are handled by going through
/// integer vectors of the same element width.
/// Returns 0 if this routine can't fold the cast; callers then fall back to
/// ConstantExpr::getBitCast.
static Constant *FoldBitCast(Constant *C, const Type *DestTy,
                             const TargetData &TD) {
  ConstantVector *CV = dyn_cast<ConstantVector>(C);
  if (!CV)
    return 0;
  const VectorType *DestVTy = dyn_cast<VectorType>(DestTy);
  if (!DestVTy)
    return 0;

  unsigned NumDstElt = DestVTy->getNumElements();
  unsigned NumSrcElt = CV->getNumOperands();
  // Same element count: no repacking is needed; leave it to the caller.
  if (NumDstElt == NumSrcElt)
    return 0;

  // The repacking below assumes each wide element is made of a whole number
  // of narrow ones.  A cast such as <3 x i32> -> <2 x i48> has equal total
  // size but straddles element boundaries.  The integer Ratio would truncate
  // (3/2 == 1) and silently yield a wrong constant, so decline to fold.
  if (NumDstElt < NumSrcElt ? (NumSrcElt % NumDstElt != 0)
                            : (NumDstElt % NumSrcElt != 0))
    return 0;

  const Type *SrcEltTy = CV->getType()->getElementType();
  const Type *DstEltTy = DestVTy->getElementType();

  // FP destination: fold to an integer vector of the same shape, then
  // bitcast that to the requested FP vector.
  if (DstEltTy->isFloatingPoint()) {
    unsigned FPWidth = DstEltTy->getPrimitiveSizeInBits();
    const Type *DestIVTy = VectorType::get(IntegerType::get(FPWidth),
                                           NumDstElt);
    C = FoldBitCast(C, DestIVTy, TD);
    if (!C) return 0;
    return ConstantExpr::getBitCast(C, DestTy);
  }

  // FP source: reinterpret it as an integer vector of the same shape first.
  if (SrcEltTy->isFloatingPoint()) {
    unsigned FPWidth = SrcEltTy->getPrimitiveSizeInBits();
    const Type *SrcIVTy = VectorType::get(IntegerType::get(FPWidth),
                                          NumSrcElt);
    C = ConstantExpr::getBitCast(C, SrcIVTy);
    CV = dyn_cast<ConstantVector>(C);
    if (!CV) return 0;
  }

  bool isLittleEndian = TD.isLittleEndian();
  SmallVector<Constant*, 32> Result;
  if (NumDstElt < NumSrcElt) {
    // Merge Ratio consecutive source elements into each destination element.
    // Little endian places the first source element in the low bits; big
    // endian places it in the high bits.
    Constant *Zero = Constant::getNullValue(DstEltTy);
    unsigned Ratio = NumSrcElt/NumDstElt;
    unsigned SrcBitSize = SrcEltTy->getPrimitiveSizeInBits();
    unsigned SrcElt = 0;
    for (unsigned i = 0; i != NumDstElt; ++i) {
      Constant *Elt = Zero;
      unsigned ShiftAmt = isLittleEndian ? 0 : SrcBitSize*(Ratio-1);
      for (unsigned j = 0; j != Ratio; ++j) {
        Constant *Src = dyn_cast<ConstantInt>(CV->getOperand(SrcElt++));
        if (!Src) return 0;   // e.g. undef element: give up.
        Src = ConstantExpr::getZExt(Src, Elt->getType());
        Src = ConstantExpr::getShl(Src,
                                   ConstantInt::get(Src->getType(), ShiftAmt));
        // Unsigned wraparound makes "+= -SrcBitSize" a subtraction.
        ShiftAmt += isLittleEndian ? SrcBitSize : -SrcBitSize;
        Elt = ConstantExpr::getOr(Elt, Src);
      }
      Result.push_back(Elt);
    }
  } else {
    // Split each source element into Ratio destination elements, taking the
    // low bits first on little endian targets and the high bits first on
    // big endian ones.
    unsigned Ratio = NumDstElt/NumSrcElt;
    unsigned DstBitSize = DstEltTy->getPrimitiveSizeInBits();
    for (unsigned i = 0; i != NumSrcElt; ++i) {
      Constant *Src = dyn_cast<ConstantInt>(CV->getOperand(i));
      if (!Src) return 0;     // e.g. undef element: give up.
      unsigned ShiftAmt = isLittleEndian ? 0 : DstBitSize*(Ratio-1);
      for (unsigned j = 0; j != Ratio; ++j) {
        Constant *Elt = ConstantExpr::getLShr(Src,
                                   ConstantInt::get(Src->getType(), ShiftAmt));
        ShiftAmt += isLittleEndian ? DstBitSize : -DstBitSize;
        Result.push_back(ConstantExpr::getTrunc(Elt, DstEltTy));
      }
    }
  }
  return ConstantVector::get(&Result[0], Result.size());
}
/// ConstantFoldInstruction - Try to fold instruction I to a constant.
/// A PHI folds when every incoming value (ignoring the PHI itself) is the
/// same constant; a PHI with no incoming values folds to undef.  Any other
/// instruction folds only if all of its operands are constants, in which
/// case the work is delegated to ConstantFoldCompareInstOperands (compares)
/// or ConstantFoldInstOperands (everything else).  Returns 0 on failure.
Constant *llvm::ConstantFoldInstruction(Instruction *I, const TargetData *TD) {
  if (PHINode *PN = dyn_cast<PHINode>(I)) {
    unsigned NumIncoming = PN->getNumIncomingValues();
    if (NumIncoming == 0)
      return UndefValue::get(PN->getType());

    Constant *Common = dyn_cast<Constant>(PN->getIncomingValue(0));
    if (!Common)
      return 0;
    for (unsigned Idx = 1; Idx != NumIncoming; ++Idx) {
      Value *Incoming = PN->getIncomingValue(Idx);
      // Self-references don't contribute a new value.
      if (Incoming != Common && Incoming != PN)
        return 0;
    }
    return Common;
  }

  // Collect the operands; bail out on the first non-constant one.
  SmallVector<Constant*, 8> ConstOps;
  for (User::op_iterator OI = I->op_begin(), OE = I->op_end(); OI != OE; ++OI) {
    Constant *Op = dyn_cast<Constant>(*OI);
    if (!Op)
      return 0;
    ConstOps.push_back(Op);
  }

  // Compares carry a predicate that the opcode alone doesn't capture.
  if (const CmpInst *Cmp = dyn_cast<CmpInst>(I))
    return ConstantFoldCompareInstOperands(Cmp->getPredicate(),
                                           &ConstOps[0], ConstOps.size(), TD);
  return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
                                  &ConstOps[0], ConstOps.size(), TD);
}
/// ConstantFoldConstantExpression - Re-fold the constant expression CE using
/// target information: its operands are passed to
/// ConstantFoldCompareInstOperands for compares, or ConstantFoldInstOperands
/// otherwise.  TD must be non-null (asserted).
Constant *llvm::ConstantFoldConstantExpression(ConstantExpr *CE,
                                               const TargetData *TD) {
  assert(TD && "ConstantFoldConstantExpression requires a valid TargetData.");

  // The operands of a ConstantExpr are expected to be constants (cast<>
  // asserts this).
  SmallVector<Constant*, 8> Operands;
  for (User::op_iterator It = CE->op_begin(), End = CE->op_end();
       It != End; ++It)
    Operands.push_back(cast<Constant>(*It));

  if (!CE->isCompare())
    return ConstantFoldInstOperands(CE->getOpcode(), CE->getType(),
                                    &Operands[0], Operands.size(), TD);
  return ConstantFoldCompareInstOperands(CE->getPredicate(),
                                         &Operands[0], Operands.size(), TD);
}
/// ConstantFoldInstOperands - Fold an instruction with opcode Opcode, result
/// type DestTy and constant operands Ops[0..NumOps-1].  TD may be null, in
/// which case target-dependent folds are skipped.  Returns the folded
/// constant, or 0 for unhandled opcodes (the default case) and calls that
/// can't be folded.  Compare opcodes are not accepted here because no
/// predicate is supplied; use ConstantFoldCompareInstOperands for them.
Constant *llvm::ConstantFoldInstOperands(unsigned Opcode, const Type *DestTy,
                                         Constant* const* Ops, unsigned NumOps,
                                         const TargetData *TD) {
  // Binary operators: try the TargetData-aware symbolic fold first (needed
  // when an operand is a ConstantExpr), then the generic ConstantExpr fold.
  if (Instruction::isBinaryOp(Opcode)) {
    if (isa<ConstantExpr>(Ops[0]) || isa<ConstantExpr>(Ops[1]))
      if (Constant *C = SymbolicallyEvaluateBinop(Opcode, Ops[0], Ops[1], TD))
        return C;
    return ConstantExpr::get(Opcode, Ops[0], Ops[1]);
  }

  switch (Opcode) {
  default: return 0;
  case Instruction::Call:
    // Ops[0] is the callee; the remaining operands are the call arguments.
    if (Function *F = dyn_cast<Function>(Ops[0]))
      if (canConstantFoldCallTo(F))
        return ConstantFoldCall(F, Ops+1, NumOps-1);
    return 0;
  case Instruction::ICmp:
  case Instruction::FCmp:
  case Instruction::VICmp:
  case Instruction::VFCmp:
    assert(0 &&"This function is invalid for compares: no predicate specified");
    // Without this return, builds with assertions disabled would fall through
    // into the PtrToInt case below and build a cast with a compare opcode.
    return 0;
  case Instruction::PtrToInt:
    // ptrtoint (inttoptr X) -> X, masked to the pointer width when X is wider
    // than a pointer, then zero-extended or truncated to DestTy.
    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[0])) {
      if (TD && CE->getOpcode() == Instruction::IntToPtr) {
        Constant *Input = CE->getOperand(0);
        unsigned InWidth = Input->getType()->getPrimitiveSizeInBits();
        if (TD->getPointerSizeInBits() < InWidth) {
          Constant *Mask =
            ConstantInt::get(APInt::getLowBitsSet(InWidth,
                                                  TD->getPointerSizeInBits()));
          Input = ConstantExpr::getAnd(Input, Mask);
        }
        return ConstantExpr::getIntegerCast(Input, DestTy, false);
      }
    }
    return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
  case Instruction::IntToPtr:
    // inttoptr (ptrtoint P) -> bitcast P, valid only when the intermediate
    // integer is at least pointer-sized so no address bits were dropped.
    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[0])) {
      if (TD && CE->getOpcode() == Instruction::PtrToInt &&
          TD->getPointerSizeInBits() <=
          CE->getType()->getPrimitiveSizeInBits()) {
        Constant *Input = CE->getOperand(0);
        Constant *C = FoldBitCast(Input, DestTy, *TD);
        return C ? C : ConstantExpr::getBitCast(Input, DestTy);
      }
    }
    return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::FPTrunc:
  case Instruction::FPExt:
  case Instruction::UIToFP:
  case Instruction::SIToFP:
  case Instruction::FPToUI:
  case Instruction::FPToSI:
      return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
  case Instruction::BitCast:
    // Vector bitcasts that change the element count need TD's endianness.
    if (TD)
      if (Constant *C = FoldBitCast(Ops[0], DestTy, *TD))
        return C;
    return ConstantExpr::getBitCast(Ops[0], DestTy);
  case Instruction::Select:
    return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]);
  case Instruction::ExtractElement:
    return ConstantExpr::getExtractElement(Ops[0], Ops[1]);
  case Instruction::InsertElement:
    return ConstantExpr::getInsertElement(Ops[0], Ops[1], Ops[2]);
  case Instruction::ShuffleVector:
    return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2]);
  case Instruction::GetElementPtr:
    if (Constant *C = SymbolicallyEvaluateGEP(Ops, NumOps, DestTy, TD))
      return C;
    return ConstantExpr::getGetElementPtr(Ops[0], Ops+1, NumOps-1);
  }
}
/// ConstantFoldCompareInstOperands - Fold a compare with predicate Predicate
/// of Ops[0] and Ops[1] (NumOps is not consulted).  When TD is available and
/// Ops[0] is an inttoptr/ptrtoint expression, the compare is first rewritten
/// in terms of the underlying operands and folded recursively.  Otherwise
/// the result is ConstantExpr::getCompare.
Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
                                                Constant*const * Ops,
                                                unsigned NumOps,
                                                const TargetData *TD) {
  ConstantExpr *LHS = dyn_cast<ConstantExpr>(Ops[0]);
  if (LHS && TD) {
    const Type *IntPtrTy = TD->getIntPtrType();
    const unsigned Opc = LHS->getOpcode();
    Constant *LHSSrc = LHS->getOperand(0);

    // Comparisons against the null value.
    if (Ops[1]->isNullValue()) {
      if (Opc == Instruction::IntToPtr) {
        // inttoptr(X) vs null  ==>  X (as intptr type) vs 0.
        Constant *X = ConstantExpr::getIntegerCast(LHSSrc, IntPtrTy, false);
        Constant *NewOps[] = { X, Constant::getNullValue(X->getType()) };
        return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD);
      }
      if (Opc == Instruction::PtrToInt && LHS->getType() == IntPtrTy) {
        // ptrtoint(P) vs 0  ==>  P vs null, when the result is the intptr type.
        Constant *NewOps[] = { LHSSrc,
                               Constant::getNullValue(LHSSrc->getType()) };
        return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD);
      }
    }

    // Both sides are the same kind of cast expression.
    ConstantExpr *RHS = dyn_cast<ConstantExpr>(Ops[1]);
    if (RHS && RHS->getOpcode() == Opc) {
      Constant *RHSSrc = RHS->getOperand(0);
      if (Opc == Instruction::IntToPtr) {
        Constant *LInt = ConstantExpr::getIntegerCast(LHSSrc, IntPtrTy, false);
        Constant *RInt = ConstantExpr::getIntegerCast(RHSSrc, IntPtrTy, false);
        Constant *NewOps[] = { LInt, RInt };
        return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD);
      }
      if (Opc == Instruction::PtrToInt && LHS->getType() == IntPtrTy &&
          LHSSrc->getType() == RHSSrc->getType()) {
        Constant *NewOps[] = { LHSSrc, RHSSrc };
        return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD);
      }
    }
  }
  return ConstantExpr::getCompare(Predicate, Ops[0], Ops[1]);
}
/// ConstantFoldLoadThroughGEPConstantExpr - Walk the getelementptr constant
/// expression CE's indices down through the aggregate constant C (presumably
/// the initializer of the object CE's base points to; confirm at call sites)
/// and return the element a load through CE would read.  Only GEPs whose
/// first index is zero are handled.  Returns 0 for a non-constant or
/// out-of-range array/vector index, or an aggregate form not inspected here.
Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C,
                                                       ConstantExpr *CE) {
  Constant *FirstIdx = CE->getOperand(1);
  if (FirstIdx != Constant::getNullValue(FirstIdx->getType()))
    return 0;

  gep_type_iterator GTI = gep_type_begin(CE), GTE = gep_type_end(CE);
  // Skip the first (zero) index and descend one aggregate level per index.
  for (++GTI; GTI != GTE; ++GTI) {
    const Type *EltTy;   // Type of the element selected at this level.

    if (const StructType *STy = dyn_cast<StructType>(*GTI)) {
      ConstantInt *FieldIdx = cast<ConstantInt>(GTI.getOperand());
      assert(FieldIdx->getZExtValue() < STy->getNumElements() &&
             "Struct index out of range!");
      unsigned Field = (unsigned)FieldIdx->getZExtValue();
      if (ConstantStruct *CS = dyn_cast<ConstantStruct>(C)) {
        C = CS->getOperand(Field);
        continue;
      }
      EltTy = STy->getElementType(Field);
    } else {
      ConstantInt *ElemIdx = dyn_cast<ConstantInt>(GTI.getOperand());
      if (!ElemIdx)
        return 0;
      uint64_t Elem = ElemIdx->getZExtValue();
      if (const ArrayType *ATy = dyn_cast<ArrayType>(*GTI)) {
        if (Elem >= ATy->getNumElements())
          return 0;
        if (ConstantArray *CA = dyn_cast<ConstantArray>(C)) {
          C = CA->getOperand(Elem);
          continue;
        }
        EltTy = ATy->getElementType();
      } else if (const VectorType *VTy = dyn_cast<VectorType>(*GTI)) {
        if (Elem >= VTy->getNumElements())
          return 0;
        if (ConstantVector *CV = dyn_cast<ConstantVector>(C)) {
          C = CV->getOperand(Elem);
          continue;
        }
        EltTy = VTy->getElementType();
      } else {
        return 0;
      }
    }

    // Every element of an all-zero / undef aggregate is zero / undef too.
    if (isa<ConstantAggregateZero>(C))
      C = Constant::getNullValue(EltTy);
    else if (isa<UndefValue>(C))
      C = UndefValue::get(EltTy);
    else
      return 0;
  }
  return C;
}
bool
llvm::canConstantFoldCallTo(const Function *F) {
  // Foldable intrinsics are recognized by ID rather than by name.
  switch (F->getIntrinsicID()) {
  case Intrinsic::sqrt:
  case Intrinsic::powi:
  case Intrinsic::bswap:
  case Intrinsic::ctpop:
  case Intrinsic::ctlz:
  case Intrinsic::cttz:
    return true;
  default:
    break;
  }

  if (!F->hasName())
    return false;

  // Library routines that ConstantFoldCall can evaluate.  Each name's length
  // is stored with it so that strcmp is only reached on a length match.
  static const struct { const char *Name; unsigned Len; } LibCalls[] = {
    { "acos", 4 }, { "asin", 4 }, { "atan", 4 }, { "atan2", 5 },
    { "ceil", 4 }, { "cos", 3 }, { "cosf", 4 }, { "cosh", 4 },
    { "exp", 3 },
    { "fabs", 4 }, { "floor", 5 }, { "fmod", 4 },
    { "log", 3 }, { "log10", 5 },
    { "pow", 3 },
    { "sin", 3 }, { "sinf", 4 }, { "sinh", 4 }, { "sqrt", 4 },
    { "sqrtf", 5 },
    { "tan", 3 }, { "tanh", 4 }
  };
  const unsigned NumLibCalls = sizeof(LibCalls) / sizeof(LibCalls[0]);

  const char *Str = F->getNameStart();
  unsigned Len = F->getNameLen();
  for (unsigned Idx = 0; Idx != NumLibCalls; ++Idx)
    if (LibCalls[Idx].Len == Len && !strcmp(Str, LibCalls[Idx].Name))
      return true;
  return false;
}
/// ConstantFoldFP - Evaluate the one-argument host math routine NativeFP on V
/// and return the result as a ConstantFP of type Ty (float or double; any
/// other type asserts).  Returns 0 if the routine set errno.
static Constant *ConstantFoldFP(double (*NativeFP)(double), double V,
                                const Type *Ty) {
  errno = 0;
  const double Folded = NativeFP(V);
  if (errno) {
    // Don't fold; clear errno again before bailing out.
    errno = 0;
    return 0;
  }

  if (Ty == Type::DoubleTy)
    return ConstantFP::get(APFloat(Folded));
  if (Ty == Type::FloatTy)
    return ConstantFP::get(APFloat((float)Folded));

  assert(0 && "Can only constant fold float/double");
  return 0;
}
/// ConstantFoldBinaryFP - Evaluate the two-argument host math routine
/// NativeFP on (V, W) and return the result as a ConstantFP of type Ty
/// (float or double; any other type asserts).  Returns 0 if the routine set
/// errno.
static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),
                                      double V, double W,
                                      const Type *Ty) {
  errno = 0;
  const double Folded = NativeFP(V, W);
  if (errno) {
    // Don't fold; clear errno again before bailing out.
    errno = 0;
    return 0;
  }

  if (Ty == Type::DoubleTy)
    return ConstantFP::get(APFloat(Folded));
  if (Ty == Type::FloatTy)
    return ConstantFP::get(APFloat((float)Folded));

  assert(0 && "Can only constant fold float/double");
  return 0;
}
/// ConstantFoldCall - Try to fold a call to F with the constant arguments
/// Operands[0..NumOperands-1]; F is expected to satisfy
/// canConstantFoldCallTo.  Handles:
///  - one-argument libm routines and llvm.sqrt.* on float/double,
///  - two-argument pow/fmod/atan2 and llvm.powi.* on float/double,
///  - llvm.bswap/ctpop/cttz/ctlz on integer constants.
/// Floating point folds are only attempted when each FP argument has the same
/// type as the return type.  Returns 0 when the call cannot be folded.
Constant *
llvm::ConstantFoldCall(Function *F,
                       Constant* const* Operands, unsigned NumOperands) {
  if (!F->hasName()) return 0;
  const char *Str = F->getNameStart();
  unsigned Len = F->getNameLen();

  const Type *Ty = F->getReturnType();
  if (NumOperands == 1) {
    if (ConstantFP *Op = dyn_cast<ConstantFP>(Operands[0])) {
      if (Ty!=Type::FloatTy && Ty!=Type::DoubleTy)
        return 0;
      // The operand is decoded using the *return* type's format below.  A
      // mismatched prototype (e.g. a "sin" returning float but passed a
      // double) would apply convertToFloat/convertToDouble to the wrong
      // APFloat semantics, so don't fold such calls.
      if (Op->getType() != Ty)
        return 0;
      double V = Ty==Type::FloatTy ? (double)Op->getValueAPF().convertToFloat():
                                     Op->getValueAPF().convertToDouble();
      switch (Str[0]) {
      case 'a':
        if (Len == 4 && !strcmp(Str, "acos"))
          return ConstantFoldFP(acos, V, Ty);
        else if (Len == 4 && !strcmp(Str, "asin"))
          return ConstantFoldFP(asin, V, Ty);
        else if (Len == 4 && !strcmp(Str, "atan"))
          return ConstantFoldFP(atan, V, Ty);
        break;
      case 'c':
        if (Len == 4 && !strcmp(Str, "ceil"))
          return ConstantFoldFP(ceil, V, Ty);
        else if (Len == 3 && !strcmp(Str, "cos"))
          return ConstantFoldFP(cos, V, Ty);
        else if (Len == 4 && !strcmp(Str, "cosh"))
          return ConstantFoldFP(cosh, V, Ty);
        else if (Len == 4 && !strcmp(Str, "cosf"))
          return ConstantFoldFP(cos, V, Ty);
        break;
      case 'e':
        if (Len == 3 && !strcmp(Str, "exp"))
          return ConstantFoldFP(exp, V, Ty);
        break;
      case 'f':
        if (Len == 4 && !strcmp(Str, "fabs"))
          return ConstantFoldFP(fabs, V, Ty);
        else if (Len == 5 && !strcmp(Str, "floor"))
          return ConstantFoldFP(floor, V, Ty);
        break;
      case 'l':
        // log/log10 are only folded for positive inputs.
        if (Len == 3 && !strcmp(Str, "log") && V > 0)
          return ConstantFoldFP(log, V, Ty);
        else if (Len == 5 && !strcmp(Str, "log10") && V > 0)
          return ConstantFoldFP(log10, V, Ty);
        else if (!strcmp(Str, "llvm.sqrt.f32") ||
                 !strcmp(Str, "llvm.sqrt.f64")) {
          // Non-negative inputs fold to sqrt; negative inputs fold to zero.
          if (V >= -0.0)
            return ConstantFoldFP(sqrt, V, Ty);
          else
            return Constant::getNullValue(Ty);
        }
        break;
      case 's':
        if (Len == 3 && !strcmp(Str, "sin"))
          return ConstantFoldFP(sin, V, Ty);
        else if (Len == 4 && !strcmp(Str, "sinh"))
          return ConstantFoldFP(sinh, V, Ty);
        else if (Len == 4 && !strcmp(Str, "sqrt") && V >= 0)
          return ConstantFoldFP(sqrt, V, Ty);
        else if (Len == 5 && !strcmp(Str, "sqrtf") && V >= 0)
          return ConstantFoldFP(sqrt, V, Ty);
        else if (Len == 4 && !strcmp(Str, "sinf"))
          return ConstantFoldFP(sin, V, Ty);
        break;
      case 't':
        if (Len == 3 && !strcmp(Str, "tan"))
          return ConstantFoldFP(tan, V, Ty);
        else if (Len == 4 && !strcmp(Str, "tanh"))
          return ConstantFoldFP(tanh, V, Ty);
        break;
      default:
        break;
      }
    } else if (ConstantInt *Op = dyn_cast<ConstantInt>(Operands[0])) {
      // Integer intrinsics are matched by name prefix.  The Len checks require
      // at least two characters (the type suffix) after the prefix and keep
      // memcmp within the name.
      if (Len > 11 && !memcmp(Str, "llvm.bswap", 10))
        return ConstantInt::get(Op->getValue().byteSwap());
      else if (Len > 11 && !memcmp(Str, "llvm.ctpop", 10))
        return ConstantInt::get(Ty, Op->getValue().countPopulation());
      else if (Len > 10 && !memcmp(Str, "llvm.cttz", 9))
        return ConstantInt::get(Ty, Op->getValue().countTrailingZeros());
      else if (Len > 10 && !memcmp(Str, "llvm.ctlz", 9))
        return ConstantInt::get(Ty, Op->getValue().countLeadingZeros());
    }
  } else if (NumOperands == 2) {
    if (ConstantFP *Op1 = dyn_cast<ConstantFP>(Operands[0])) {
      if (Ty!=Type::FloatTy && Ty!=Type::DoubleTy)
        return 0;
      // As above: Op1 is decoded with the return type's format.
      if (Op1->getType() != Ty)
        return 0;
      double Op1V = Ty==Type::FloatTy ?
                      (double)Op1->getValueAPF().convertToFloat():
                      Op1->getValueAPF().convertToDouble();
      if (ConstantFP *Op2 = dyn_cast<ConstantFP>(Operands[1])) {
        if (Op2->getType() != Ty)
          return 0;
        double Op2V = Ty==Type::FloatTy ?
                      (double)Op2->getValueAPF().convertToFloat():
                      Op2->getValueAPF().convertToDouble();

        if (Len == 3 && !strcmp(Str, "pow")) {
          return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty);
        } else if (Len == 4 && !strcmp(Str, "fmod")) {
          return ConstantFoldBinaryFP(fmod, Op1V, Op2V, Ty);
        } else if (Len == 5 && !strcmp(Str, "atan2")) {
          return ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty);
        }
      } else if (ConstantInt *Op2C = dyn_cast<ConstantInt>(Operands[1])) {
        // The llvm.powi exponent is a signed integer: sign-extend it rather
        // than zero-extending and relying on a narrowing conversion.
        int Exponent = static_cast<int>(Op2C->getSExtValue());
        if (!strcmp(Str, "llvm.powi.f32")) {
          return ConstantFP::get(APFloat((float)std::pow((float)Op1V,
                                                         Exponent)));
        } else if (!strcmp(Str, "llvm.powi.f64")) {
          return ConstantFP::get(APFloat((double)std::pow((double)Op1V,
                                                          Exponent)));
        }
      }
    }
  }
  return 0;
}