// RaiseAllocations.cpp - Raise calls to malloc/free into MallocInst/FreeInst.
#define DEBUG_TYPE "raiseallocs"
#include "llvm/Transforms/IPO.h"
#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Module.h"
#include "llvm/Instructions.h"
#include "llvm/Pass.h"
#include "llvm/Support/CallSite.h"
#include "llvm/Support/Compiler.h"
#include "llvm/ADT/Statistic.h"
#include <algorithm>
#include <set>
using namespace llvm;
// Counts every malloc/free call site that runOnModule rewrites.
STATISTIC(NumRaised,
          "Number of allocations raised");

namespace {
  /// RaiseAllocations - Module pass that rewrites calls to the module's
  /// "malloc" and "free" declarations into MallocInst / FreeInst.
  class VISIBILITY_HIDDEN RaiseAllocations : public ModulePass {
  public:
    static char ID; // Its address is handed to the ModulePass constructor.

    RaiseAllocations() : ModulePass(&ID), MallocFunc(0), FreeFunc(0) {}

    /// Looks up usable declarations of malloc and free in \p M and caches
    /// them in MallocFunc / FreeFunc (left null when not usable).
    void doInitialization(Module &M);

    /// Performs the rewrite; returns true if any call site was changed.
    bool runOnModule(Module &M);

  private:
    Function *MallocFunc; // Cached "malloc" declaration, or null.
    Function *FreeFunc;   // Cached "free" declaration, or null.
  };
} // end anonymous namespace
// Pass identifier storage; the RaiseAllocations constructor passes &ID to
// ModulePass, so the address is what identifies the pass.
char RaiseAllocations::ID = 0;

// Static registration object: pass name "raiseallocs" plus a description.
static RegisterPass<RaiseAllocations>
X("raiseallocs",
  "Raise allocations from calls to instructions");
/// createRaiseAllocationsPass - Factory for the pass.  Returns a freshly
/// heap-allocated RaiseAllocations; nothing here keeps a reference to it
/// (presumably the caller, e.g. a PassManager, takes ownership - confirm).
ModulePass *llvm::createRaiseAllocationsPass() {
  RaiseAllocations *P = new RaiseAllocations();
  return P;
}
/// doInitialization - Caches the module's "malloc" and "free" functions in
/// MallocFunc / FreeFunc, but only if their prototype is one this pass knows:
///   malloc:  i8* (i64)   |  i8* (i32)   |  i8* (...)
///   free:    void (i8*)  |  void (...)  |  i32 (...)
/// A function with any other prototype, or one that has a body in this
/// module, leaves the corresponding member null.
void RaiseAllocations::doInitialization(Module &M) {
  MallocFunc = M.getFunction("malloc");
  if (MallocFunc) {
    const Type *BytePtrTy = PointerType::getUnqual(Type::Int8Ty);
    const FunctionType *Actual = MallocFunc->getFunctionType();
    // Short-circuiting: a candidate prototype is only built when every
    // earlier one failed to match.
    bool Known =
      Actual == FunctionType::get(BytePtrTy,
                   std::vector<const Type*>(1, Type::Int64Ty), false) ||
      // 'void *malloc(unsigned);'
      Actual == FunctionType::get(BytePtrTy,
                   std::vector<const Type*>(1, Type::Int32Ty), false) ||
      // 'void *malloc();' - no prototype, varargs.
      Actual == FunctionType::get(BytePtrTy,
                   std::vector<const Type*>(), true);
    if (!Known)
      MallocFunc = 0;
  }

  FreeFunc = M.getFunction("free");
  if (FreeFunc) {
    const Type *BytePtrTy = PointerType::getUnqual(Type::Int8Ty);
    const FunctionType *Actual = FreeFunc->getFunctionType();
    bool Known =
      // 'void free(void *);'
      Actual == FunctionType::get(Type::VoidTy,
                   std::vector<const Type*>(1, BytePtrTy), false) ||
      // 'void free();' - forgotten prototype.
      Actual == FunctionType::get(Type::VoidTy,
                   std::vector<const Type*>(), true) ||
      // 'int free(...)' - implicit declaration.
      Actual == FunctionType::get(Type::Int32Ty,
                   std::vector<const Type*>(), true);
    if (!Known)
      FreeFunc = 0;
  }

  // Locally defined versions of these functions are left alone.
  if (MallocFunc && !MallocFunc->isDeclaration())
    MallocFunc = 0;
  if (FreeFunc && !FreeFunc->isDeclaration())
    FreeFunc = 0;
}
bool RaiseAllocations::runOnModule(Module &M) {
doInitialization(M);
bool Changed = false;
if (MallocFunc) {
std::vector<User*> Users(MallocFunc->use_begin(), MallocFunc->use_end());
std::vector<Value*> EqPointers; while (!Users.empty()) {
User *U = Users.back();
Users.pop_back();
if (Instruction *I = dyn_cast<Instruction>(U)) {
CallSite CS = CallSite::get(I);
if (CS.getInstruction() && !CS.arg_empty() &&
(CS.getCalledFunction() == MallocFunc ||
std::find(EqPointers.begin(), EqPointers.end(),
CS.getCalledValue()) != EqPointers.end())) {
Value *Source = *CS.arg_begin();
if (Source->getType() != Type::Int32Ty)
Source =
CastInst::CreateIntegerCast(Source, Type::Int32Ty, false,
"MallocAmtCast", I);
MallocInst *MI = new MallocInst(Type::Int8Ty, Source, "", I);
MI->takeName(I);
I->replaceAllUsesWith(MI);
if (InvokeInst *II = dyn_cast<InvokeInst>(I))
BranchInst::Create(II->getNormalDest(), I);
I->eraseFromParent();
Changed = true;
++NumRaised;
}
} else if (GlobalValue *GV = dyn_cast<GlobalValue>(U)) {
Users.insert(Users.end(), GV->use_begin(), GV->use_end());
EqPointers.push_back(GV);
} else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U)) {
if (CE->isCast()) {
Users.insert(Users.end(), CE->use_begin(), CE->use_end());
EqPointers.push_back(CE);
}
}
}
}
if (FreeFunc) {
std::vector<User*> Users(FreeFunc->use_begin(), FreeFunc->use_end());
std::vector<Value*> EqPointers;
while (!Users.empty()) {
User *U = Users.back();
Users.pop_back();
if (Instruction *I = dyn_cast<Instruction>(U)) {
if (isa<InvokeInst>(I))
continue;
CallSite CS = CallSite::get(I);
if (CS.getInstruction() && !CS.arg_empty() &&
(CS.getCalledFunction() == FreeFunc ||
std::find(EqPointers.begin(), EqPointers.end(),
CS.getCalledValue()) != EqPointers.end())) {
Value *Source = *CS.arg_begin();
if (!isa<PointerType>(Source->getType()))
Source = new IntToPtrInst(Source,
PointerType::getUnqual(Type::Int8Ty),
"FreePtrCast", I);
new FreeInst(Source, I);
if (InvokeInst *II = dyn_cast<InvokeInst>(I))
BranchInst::Create(II->getNormalDest(), I);
if (I->getType() != Type::VoidTy)
I->replaceAllUsesWith(UndefValue::get(I->getType()));
I->eraseFromParent();
Changed = true;
++NumRaised;
}
} else if (GlobalValue *GV = dyn_cast<GlobalValue>(U)) {
Users.insert(Users.end(), GV->use_begin(), GV->use_end());
EqPointers.push_back(GV);
} else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(U)) {
if (CE->isCast()) {
Users.insert(Users.end(), CE->use_begin(), CE->use_end());
EqPointers.push_back(CE);
}
}
}
}
return Changed;
}