// NVPTXLowerStructArgs.cpp
#include "NVPTX.h"
#include "NVPTXUtilities.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
using namespace llvm;
// Forward declaration of the pass initialization hook, placed in namespace
// llvm. No definition is written out in this file; it is presumably generated
// by the INITIALIZE_PASS macro invocation further down — confirm against the
// macro's definition.
namespace llvm {
void initializeNVPTXLowerStructArgsPass(PassRegistry &);
}
// Function pass for NVPTX kernel functions. For every pointer argument that
// carries the byval attribute, the pointed-to value is copied into an alloca
// at the start of the entry block and all uses of the argument are redirected
// to that alloca.
class LLVM_LIBRARY_VISIBILITY NVPTXLowerStructArgs : public FunctionPass {
public:
  // Pass identification member; handed to the FunctionPass base constructor.
  static char ID;

  NVPTXLowerStructArgs() : FunctionPass(ID) {}

  // Human-readable name of this pass.
  const char *getPassName() const override {
    return "Copy structure (byval *) arguments to stack";
  }

private:
  // Per-function entry point; only kernel functions are processed.
  bool runOnFunction(Function &F) override;

  // Applies handleParam to each byval pointer argument of the function.
  void handleStructPtrArgs(Function &);

  // Emits the copy-to-alloca sequence for a single byval pointer argument.
  void handleParam(Argument *);
};
// Out-of-class definition of the static ID member that the constructor passes
// to FunctionPass.
char NVPTXLowerStructArgs::ID = 1;
// Pass registration. Arguments: the pass class, its short argument string, a
// description, and two boolean flags (presumably "CFG only" and "is analysis"
// — confirm against the INITIALIZE_PASS definition).
INITIALIZE_PASS(NVPTXLowerStructArgs, "nvptx-lower-struct-args",
"Lower structure arguments (NVPTX)", false, false)
// Copies the value behind a byval pointer argument into a fresh alloca and
// redirects every existing use of the argument to that alloca.
//
// All new instructions are inserted, in creation order, in front of the first
// instruction of the entry block:
//   1. an alloca of the pointee type, aligned like the parameter;
//   2. a bitcast of the argument to i8* in the generic address space;
//   3. a call to the nvvm_ptr_gen_to_param intrinsic on that pointer;
//   4. a bitcast of the result to a pointee-typed pointer in the param
//      address space;
//   5. a load of the whole value through that pointer;
//   6. a store of the loaded value into the alloca.
//
// Arg must have pointer type (asserted).
void NVPTXLowerStructArgs::handleParam(Argument *Arg) {
  Function *Kernel = Arg->getParent();
  Instruction *InsertPt = &Kernel->getEntryBlock().front();

  PointerType *ArgPtrTy = dyn_cast<PointerType>(Arg->getType());
  assert(ArgPtrTy && "Expecting pointer type in handleParam");
  Type *PointeeTy = ArgPtrTy->getElementType();

  // Local copy of the argument. Its alignment is taken from the parameter,
  // presumably because the redirected uses were written against that
  // alignment. NOTE(review): the "+ 1" looks like the attribute-index offset
  // past the return value — confirm against the Attributes API in use.
  AllocaInst *LocalCopy = new AllocaInst(PointeeTy, Arg->getName(), InsertPt);
  LocalCopy->setAlignment(Kernel->getParamAlignment(Arg->getArgNo() + 1));

  // Redirect the existing uses now, before the instructions below start
  // using Arg themselves; those later uses must keep referring to Arg.
  Arg->replaceAllUsesWith(LocalCopy);

  Module *Mod = Kernel->getParent();
  LLVMContext &Ctx = Mod->getContext();
  Type *ParamI8PtrTy = Type::getInt8PtrTy(Ctx, ADDRESS_SPACE_PARAM);
  Type *GenericI8PtrTy = Type::getInt8PtrTy(Ctx, ADDRESS_SPACE_GENERIC);

  // Declaration of the generic-to-param conversion intrinsic, overloaded on
  // (result type, operand type).
  Type *OverloadTys[] = {ParamI8PtrTy, GenericI8PtrTy};
  Function *GenToParam = Intrinsic::getDeclaration(
      Mod, Intrinsic::nvvm_ptr_gen_to_param, OverloadTys);

  // View the incoming argument as a generic i8* and convert it.
  Value *ArgAsI8Ptr =
      new BitCastInst(Arg, GenericI8PtrTy, Arg->getName(), InsertPt);
  Value *CallOperands[] = {ArgAsI8Ptr};
  CallInst *ParamI8Ptr =
      CallInst::Create(GenToParam, CallOperands, "cvt_to_param", InsertPt);

  // Re-type the converted pointer and copy the whole value into the alloca.
  BitCastInst *ParamPtr = new BitCastInst(
      ParamI8Ptr, PointerType::get(PointeeTy, ADDRESS_SPACE_PARAM),
      Arg->getName(), InsertPt);
  LoadInst *Loaded = new LoadInst(ParamPtr, Arg->getName(), InsertPt);
  new StoreInst(Loaded, LocalCopy, InsertPt);
}
// Walks the arguments of F and runs handleParam on each one that is both
// pointer-typed and marked byval. All other arguments are left untouched.
void NVPTXLowerStructArgs::handleStructPtrArgs(Function &F) {
  for (Argument &Param : F.args()) {
    // Only pointer arguments are candidates.
    if (!Param.getType()->isPointerTy())
      continue;
    // Of those, only byval ones are copied.
    if (!Param.hasByValAttr())
      continue;
    handleParam(&Param);
  }
}
// Pass entry point.
//
// Non-kernel functions are skipped. For a kernel function, every byval
// pointer argument is lowered through handleStructPtrArgs.
//
// Returns true only when the function was actually modified, i.e. when it is
// a kernel with at least one byval pointer argument. Returns false otherwise,
// so an untouched function is not reported as changed.
bool NVPTXLowerStructArgs::runOnFunction(Function &F) {
  if (!isKernelFunction(F))
    return false;

  // Same predicate handleStructPtrArgs applies before calling handleParam;
  // keep the two in sync. handleParam always inserts instructions, so the
  // function is modified exactly when some argument matches.
  bool HasByValPtrArg = false;
  for (const Argument &Arg : F.args()) {
    if (Arg.getType()->isPointerTy() && Arg.hasByValAttr()) {
      HasByValPtrArg = true;
      break;
    }
  }
  if (!HasByValPtrArg)
    return false;

  handleStructPtrArgs(F);
  return true;
}
// Factory for the pass: heap-allocates a new NVPTXLowerStructArgs and returns
// it as a FunctionPass pointer. Nothing here releases the object, so ownership
// presumably passes to the caller — confirm at the call site.
FunctionPass *llvm::createNVPTXLowerStructArgsPass() {
  FunctionPass *LowerStructArgs = new NVPTXLowerStructArgs();
  return LowerStructArgs;
}