// ImplicitNullChecks.cpp
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/MachineOperand.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Target/TargetSubtargetInfo.h"
#include "llvm/Target/TargetInstrInfo.h"
using namespace llvm;
/// Command-line tunable giving the target page size in bytes (default 4096).
/// analyzeBlockForNullChecks only accepts a load whose immediate offset from
/// the tested pointer register is strictly smaller than this value.
static cl::opt<unsigned> PageSize(
    "imp-null-check-page-size",
    cl::desc("The page size of the target in bytes"),
    cl::init(4096));
// Name under which this file identifies itself to LLVM macros that consult
// DEBUG_TYPE (such as the STATISTIC declaration right below).
#define DEBUG_TYPE "implicit-null-checks"
// Counter bumped once for every null check rewritten in rewriteNullChecks.
STATISTIC(NumImplicitNullChecks,
"Number of explicit null checks made implicit");
namespace {

/// Machine function pass that looks for explicit "pointer == 0" branches and
/// replaces each eligible one with a FAULTING_LOAD_OP pseudo instruction plus
/// a handler label in the null successor.
class ImplicitNullChecks : public MachineFunctionPass {
  /// Everything needed to rewrite one explicit null check.
  struct NullCheck {
    /// The load found in NotNullSucc that will become the faulting load.
    MachineInstr *MemOperation = nullptr;

    /// The instruction recorded as the branch condition's definition
    /// (MachineBranchPredicate::ConditionDef); erased during the rewrite.
    MachineInstr *CheckOperation = nullptr;

    /// The block that ends in the null-testing branch.
    MachineBasicBlock *CheckBlock = nullptr;

    /// Successor taken when the pointer is not null.
    MachineBasicBlock *NotNullSucc = nullptr;

    /// Successor taken when the pointer is null.
    MachineBasicBlock *NullSucc = nullptr;

    /// Creates an empty record with every field set to null.
    NullCheck() = default;

    /// Creates a fully populated record.
    explicit NullCheck(MachineInstr *MemOp, MachineInstr *CheckOp,
                       MachineBasicBlock *CheckBB,
                       MachineBasicBlock *NotNullBB,
                       MachineBasicBlock *NullBB)
        : MemOperation(MemOp), CheckOperation(CheckOp), CheckBlock(CheckBB),
          NotNullSucc(NotNullBB), NullSucc(NullBB) {}
  };

  // Per-function context, refreshed at the start of runOnMachineFunction.
  const TargetInstrInfo *TII = nullptr;
  const TargetRegisterInfo *TRI = nullptr;
  MachineModuleInfo *MMI = nullptr;

  /// Examines the branch ending \p MBB and appends a NullCheck to
  /// \p NullCheckList when the check can be made implicit. Returns true if a
  /// record was appended.
  bool analyzeBlockForNullChecks(MachineBasicBlock &MBB,
                                 SmallVectorImpl<NullCheck> &NullCheckList);

  /// Builds a FAULTING_LOAD_OP in \p MBB that mirrors \p LoadMI and names
  /// \p HandlerLabel as its handler symbol.
  MachineInstr *insertFaultingLoad(MachineInstr *LoadMI, MachineBasicBlock *MBB,
                                   MCSymbol *HandlerLabel);

  /// Applies the rewrite described by every entry of \p NullCheckList.
  void rewriteNullChecks(ArrayRef<NullCheck> NullCheckList);

public:
  /// Pass identification.
  static char ID;

  ImplicitNullChecks() : MachineFunctionPass(ID) {
    initializeImplicitNullChecksPass(*PassRegistry::getPassRegistry());
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
};

} // end anonymous namespace
/// Pass entry point: collects null-check candidates from every block of
/// \p MF and rewrites them.
///
/// \returns true if at least one null check was rewritten (i.e. the function
/// was modified), false otherwise.
bool ImplicitNullChecks::runOnMachineFunction(MachineFunction &MF) {
  // Refresh the cached per-function context used by the helper methods.
  MMI = &MF.getMMI();
  TRI = MF.getRegInfo().getTargetRegisterInfo();
  TII = MF.getSubtarget().getInstrInfo();

  // Analysis only records candidates; nothing is mutated until the whole
  // function has been scanned.
  SmallVector<NullCheck, 16> Candidates;
  for (MachineBasicBlock &Block : MF)
    analyzeBlockForNullChecks(Block, Candidates);

  if (Candidates.empty())
    return false;

  rewriteNullChecks(Candidates);
  return true;
}
/// Decides whether the conditional branch ending \p MBB is a null check that
/// can be folded into a load located in its not-null successor.
///
/// On success one NullCheck describing the rewrite is appended to
/// \p NullCheckList.
///
/// \returns true if a record was appended, false otherwise.
bool ImplicitNullChecks::analyzeBlockForNullChecks(
    MachineBasicBlock &MBB, SmallVectorImpl<NullCheck> &NullCheckList) {
  using BranchPredicate = TargetInstrInfo::MachineBranchPredicate;

  // Only blocks whose IR terminator carries MD_make_implicit metadata are
  // eligible; blocks with no associated IR block are skipped.
  const BasicBlock *IRBlock = MBB.getBasicBlock();
  if (!IRBlock)
    return false;
  if (!IRBlock->getTerminator()->getMetadata(LLVMContext::MD_make_implicit))
    return false;

  // A true result from the target hook means the branch was not analyzable.
  BranchPredicate Pred;
  if (TII->AnalyzeBranchPredicate(MBB, Pred, true))
    return false;

  // The condition has to be "register ==/!= immediate 0".
  if (!Pred.LHS.isReg())
    return false;
  if (!Pred.RHS.isImm() || Pred.RHS.getImm() != 0)
    return false;
  const bool BranchesOnNonNull = Pred.Predicate == BranchPredicate::PRED_NE;
  if (!BranchesOnNonNull && Pred.Predicate != BranchPredicate::PRED_EQ)
    return false;

  // The comparison gets erased by the rewrite, so it must not have other
  // users.
  if (!Pred.SingleUseCondition)
    return false;

  // Work out which successor corresponds to which outcome of the test.
  MachineBasicBlock *NotNullSucc =
      BranchesOnNonNull ? Pred.TrueDest : Pred.FalseDest;
  MachineBasicBlock *NullSucc =
      BranchesOnNonNull ? Pred.FalseDest : Pred.TrueDest;

  // The load is moved out of NotNullSucc into MBB, so MBB must be the only
  // way to reach NotNullSucc.
  if (NotNullSucc->pred_size() != 1)
    return false;

  const unsigned PointerReg = Pred.LHS.getReg();

  // Registers written / read by the instructions already walked past in
  // NotNullSucc. They determine whether a later load may be hoisted above
  // those instructions.
  DenseSet<unsigned> SeenDefs, SeenUses;

  // True if any memory operand of MI is not "unordered".
  auto HasOrderedMemOperand = [](MachineInstr *MI) {
    for (auto *MMO : MI->memoperands())
      if (!MMO->isUnordered())
        return true;
    return false;
  };

  // True if Reg overlaps any register contained in Regs.
  auto OverlapsAny = [&](const DenseSet<unsigned> &Regs, unsigned Reg) {
    for (unsigned Seen : Regs)
      if (TRI->regsOverlap(Seen, Reg))
        return true;
    return false;
  };

  // True if MI can be moved above everything walked past so far.
  auto IsSafeToHoist = [&](MachineInstr *MI) {
    if (HasOrderedMemOperand(MI))
      return false;

    for (auto &MO : MI->operands()) {
      if (!MO.isReg() || !MO.getReg())
        continue;

      // Any operand touching an earlier def would be a read-after-write or
      // write-after-write hazard.
      if (OverlapsAny(SeenDefs, MO.getReg()))
        return false;

      // A def touching an earlier use would be a write-after-read hazard.
      if (MO.isDef() && OverlapsAny(SeenUses, MO.getReg()))
        return false;
    }

    return true;
  };

  for (MachineInstr &Instr : *NotNullSucc) {
    MachineInstr *MI = &Instr;

    // Candidate test first: a non-predicable load based on the tested
    // pointer, with an offset below the page size and at most one def.
    unsigned BaseReg, Offset;
    if (TII->getMemOpBaseRegImmOfs(MI, BaseReg, Offset, TRI) &&
        MI->mayLoad() && !MI->isPredicable() && BaseReg == PointerReg &&
        Offset < PageSize && MI->getDesc().getNumDefs() <= 1 &&
        IsSafeToHoist(MI)) {
      NullCheckList.emplace_back(MI, Pred.ConditionDef, &MBB, NotNullSucc,
                                 NullSucc);
      return true;
    }

    // Not a candidate. Give up on instructions a load must not be hoisted
    // across: stores, unmodeled side effects, ordered memory accesses.
    if (MI->mayStore() || MI->hasUnmodeledSideEffects())
      return false;
    if (HasOrderedMemOperand(MI))
      return false;

    // Otherwise remember what this instruction reads and writes.
    for (auto &MO : MI->operands()) {
      if (MO.isReg() && MO.getReg())
        (MO.isDef() ? SeenDefs : SeenUses).insert(MO.getReg());
    }
  }

  return false;
}
/// Emits a FAULTING_LOAD_OP into \p MBB that stands in for \p LoadMI.
///
/// The operands of the new instruction are, in order: the register defined by
/// \p LoadMI (register 0 when it defines nothing), \p HandlerLabel, the opcode
/// of \p LoadMI as an immediate, and then every use operand of \p LoadMI. The
/// memory operands of \p LoadMI are attached as well. \p LoadMI itself is left
/// untouched.
///
/// \returns the newly built instruction.
MachineInstr *ImplicitNullChecks::insertFaultingLoad(MachineInstr *LoadMI,
                                                     MachineBasicBlock *MBB,
                                                     MCSymbol *HandlerLabel) {
  const unsigned DefCount = LoadMI->getDesc().getNumDefs();
  assert(DefCount <= 1 && "other cases unhandled!");

  // Register number 0 is used when the load has no result register.
  unsigned ResultReg = 0;
  if (DefCount != 0) {
    auto Defs = LoadMI->defs();
    ResultReg = Defs.begin()->getReg();
    assert(std::distance(Defs.begin(), Defs.end()) == 1 &&
           "expected exactly one def!");
  }

  // The pseudo instruction is created with an empty debug location.
  DebugLoc NoLoc;
  MachineInstrBuilder Builder = BuildMI(
      MBB, NoLoc, TII->get(TargetOpcode::FAULTING_LOAD_OP), ResultReg);
  Builder.addSym(HandlerLabel);
  Builder.addImm(LoadMI->getOpcode());

  // Forward the original load's inputs unchanged.
  for (auto &UseMO : LoadMI->uses())
    Builder.addOperand(UseMO);

  Builder.setMemRefs(LoadMI->memoperands_begin(), LoadMI->memoperands_end());

  return Builder;
}
/// Performs the rewrite for every entry of \p NullCheckList.
///
/// For each entry: the branch ending the check block is removed, a faulting
/// load replacing the original memory operation is inserted into the check
/// block, the original load and the compare are erased, an unconditional
/// branch to the not-null successor is inserted, and an EH_LABEL carrying the
/// handler symbol is placed at the start of the null successor.
void ImplicitNullChecks::rewriteNullChecks(
    ArrayRef<ImplicitNullChecks::NullCheck> NullCheckList) {
  DebugLoc NoLoc;

  for (const NullCheck &Check : NullCheckList) {
    MachineBasicBlock &CheckBB = *Check.CheckBlock;
    MachineBasicBlock &NullBB = *Check.NullSucc;

    // Fresh symbol that ties the faulting load to its handler location.
    MCSymbol *FaultLabel = MMI->getContext().createTempSymbol();

    // Drop the explicit branch(es) terminating the check block.
    unsigned RemovedCount = TII->RemoveBranch(CheckBB);
    (void)RemovedCount;
    assert(RemovedCount > 0 && "expected at least one branch!");

    // Put the faulting load into the check block, then delete the load it
    // replaces together with the now unused compare.
    insertFaultingLoad(Check.MemOperation, Check.CheckBlock, FaultLabel);
    Check.MemOperation->eraseFromParent();
    Check.CheckOperation->eraseFromParent();

    // The not-null path is now reached unconditionally.
    TII->InsertBranch(CheckBB, Check.NotNullSucc, /*FBB=*/nullptr,
                      /*Cond=*/None, NoLoc);

    // Mark the beginning of the null successor with the handler label.
    BuildMI(NullBB, NullBB.begin(), NoLoc, TII->get(TargetOpcode::EH_LABEL))
        .addSym(FaultLabel);

    NumImplicitNullChecks++;
  }
}
// Definition of the static pass identifier; the ImplicitNullChecks
// constructor hands it to the MachineFunctionPass base class.
char ImplicitNullChecks::ID = 0;
// Makes the identifier reachable as llvm::ImplicitNullChecksID.
char &llvm::ImplicitNullChecksID = ImplicitNullChecks::ID;
// Pass registration under the name "implicit-null-checks" with the description
// "Implicit null checks". No pass dependencies are listed between BEGIN and
// END.
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags -- confirm against the macro definition.
INITIALIZE_PASS_BEGIN(ImplicitNullChecks, "implicit-null-checks",
"Implicit null checks", false, false)
INITIALIZE_PASS_END(ImplicitNullChecks, "implicit-null-checks",
"Implicit null checks", false, false)