// AArch64BranchFixupPass.cpp
#define DEBUG_TYPE "aarch64-branch-fixup"
#include "AArch64.h"
#include "AArch64InstrInfo.h"
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;
// Pass statistics. NumSplit is incremented in splitBlockBeforeInstr, once per
// block split (each split appends an unconditional branch to the new block).
// NumCBrFixed is incremented in fixupConditionalBr, once per conditional
// branch rewritten because its target was out of range.
STATISTIC(NumSplit, "Number of uncond branches inserted");
STATISTIC(NumCBrFixed, "Number of cond branches fixed");
/// Worst-case number of padding bytes that may be needed to reach a
/// 2^LogAlign-byte boundary when the current offset is only known to be a
/// multiple of 2^KnownBits. Returns 0 when the known alignment already
/// satisfies the requested one.
static inline unsigned UnknownPadding(unsigned LogAlign, unsigned KnownBits) {
  return KnownBits >= LogAlign ? 0u : (1u << LogAlign) - (1u << KnownBits);
}
namespace {
/// Machine-function pass that estimates the byte offset of every basic block,
/// finds direct branches whose target lies outside the range of their
/// immediate field, and rewrites the out-of-range conditional ones into an
/// inverted short-range branch plus an unconditional B.
class AArch64BranchFixup : public MachineFunctionPass {
  /// Layout estimate for one basic block (stored in BBInfo, indexed by the
  /// block's number).
  struct BasicBlockInfo {
    /// Byte offset of the start of the block from the start of the function.
    unsigned Offset;
    /// Size of the block in bytes (sum of getInstSizeInBytes over its
    /// instructions; see computeBlockSize).
    unsigned Size;
    /// log2 of the alignment known to hold at Offset, i.e. the number of low
    /// bits of Offset known to be zero.
    uint8_t KnownBits;
    /// When nonzero, used instead of KnownBits as the starting point for the
    /// alignment known at the end of the block. computeBlockSize sets it to 2
    /// for blocks containing inline asm.
    uint8_t Unalign;
    BasicBlockInfo() : Offset(0), Size(0), KnownBits(0), Unalign(0) {}
    /// Known alignment bits at the end of the block (Offset + Size). Starts
    /// from Unalign (or KnownBits if Unalign is 0); if Size is not a multiple
    /// of that alignment, falls back to the trailing-zero count of Size.
    unsigned internalKnownBits() const {
      unsigned Bits = Unalign ? Unalign : KnownBits;
      if (Size & ((1u << Bits) - 1))
        Bits = countTrailingZeros(Size);
      return Bits;
    }
    /// Offset just past the end of this block. With a nonzero LogAlign
    /// (the following block's alignment), adds the worst-case padding
    /// that alignment might require.
    unsigned postOffset(unsigned LogAlign = 0) const {
      unsigned PO = Offset + Size;
      if (!LogAlign)
        return PO;
      return PO + UnknownPadding(LogAlign, internalKnownBits());
    }
    /// Known alignment bits at the start of a following block that is
    /// aligned to 2^LogAlign.
    unsigned postKnownBits(unsigned LogAlign = 0) const {
      return std::max(LogAlign, internalKnownBits());
    }
  };
  /// Layout estimate for every block, indexed by MachineBasicBlock number.
  std::vector<BasicBlockInfo> BBInfo;
  /// A direct branch whose reach is tracked by this pass.
  struct ImmBranch {
    /// The branch instruction.
    MachineInstr *MI;
    /// Width in bits of the signed byte displacement the branch can encode
    /// (checked with isIntN in isBBInRange).
    unsigned OffsetBits : 31;
    /// True for conditional branches, the only kind this pass rewrites.
    bool IsCond : 1;
    ImmBranch(MachineInstr *mi, unsigned offsetbits, bool cond)
      : MI(mi), OffsetBits(offsetbits), IsCond(cond) {}
  };
  /// All tracked branches. Collected by initializeFunctionInfo and extended
  /// whenever fixupConditionalBr appends a new unconditional B.
  std::vector<ImmBranch> ImmBranches;
  /// Function being processed and its instruction info; set at the start of
  /// runOnMachineFunction.
  MachineFunction *MF;
  const AArch64InstrInfo *TII;
public:
  static char ID;
  AArch64BranchFixup() : MachineFunctionPass(ID) {}
  virtual bool runOnMachineFunction(MachineFunction &MF);
  virtual const char *getPassName() const {
    return "AArch64 branch fixup pass";
  }
private:
  void initializeFunctionInfo();
  MachineBasicBlock *splitBlockBeforeInstr(MachineInstr *MI);
  void adjustBBOffsetsAfter(MachineBasicBlock *BB);
  bool isBBInRange(MachineInstr *MI, MachineBasicBlock *BB,
                   unsigned OffsetBits);
  bool fixupImmediateBr(ImmBranch &Br);
  bool fixupConditionalBr(ImmBranch &Br);
  void computeBlockSize(MachineBasicBlock *MBB);
  unsigned getOffsetOf(MachineInstr *MI) const;
  void dumpBBs();
  void verify();
};
char AArch64BranchFixup::ID = 0;
}
void AArch64BranchFixup::verify() {
#ifndef NDEBUG
for (MachineFunction::iterator MBBI = MF->begin(), E = MF->end();
MBBI != E; ++MBBI) {
MachineBasicBlock *MBB = MBBI;
unsigned MBBId = MBB->getNumber();
assert(!MBBId || BBInfo[MBBId - 1].postOffset() <= BBInfo[MBBId].Offset);
}
#endif
}
void AArch64BranchFixup::dumpBBs() {
DEBUG({
for (unsigned J = 0, E = BBInfo.size(); J !=E; ++J) {
const BasicBlockInfo &BBI = BBInfo[J];
dbgs() << format("%08x BB#%u\t", BBI.Offset, J)
<< " kb=" << unsigned(BBI.KnownBits)
<< " ua=" << unsigned(BBI.Unalign)
<< format(" size=%#x\n", BBInfo[J].Size);
}
});
}
/// Factory for the AArch64 branch fixup pass. Returns a newly heap-allocated
/// pass object; ownership passes to the caller.
FunctionPass *llvm::createAArch64BranchFixupPass() {
  AArch64BranchFixup *Pass = new AArch64BranchFixup();
  return Pass;
}
/// Pass entry point. Builds the block-offset estimate, then makes repeated
/// sweeps over all tracked branches, fixing any that cannot reach their
/// target, until a sweep changes nothing. Calls report_fatal_error if more
/// than 30 sweeps make changes. Returns true if any branch was rewritten.
bool AArch64BranchFixup::runOnMachineFunction(MachineFunction &mf) {
  MF = &mf;
  DEBUG(dbgs() << "***** AArch64BranchFixup ******");
  TII = static_cast<const AArch64InstrInfo *>(MF->getTarget().getInstrInfo());

  // This pass splits blocks and rewrites branches, so liveness is invalidated
  // up front. BBInfo is indexed by block number, hence the renumbering.
  MF->getRegInfo().invalidateLiveness();
  MF->RenumberBlocks();
  initializeFunctionInfo();

  bool MadeChange = false;
  for (unsigned Iter = 0;;) {
    DEBUG(dbgs() << "Beginning iteration #" << Iter << '\n');
    // Branches appended during this sweep are examined in the next one.
    bool Changed = false;
    const unsigned NumBranches = ImmBranches.size();
    for (unsigned Idx = 0; Idx != NumBranches; ++Idx)
      Changed |= fixupImmediateBr(ImmBranches[Idx]);
    if (Changed && ++Iter > 30)
      report_fatal_error("Branch Fix Up pass failed to converge!");
    DEBUG(dumpBBs());
    if (!Changed)
      break;
    MadeChange = true;
  }

  verify();
  DEBUG(dbgs() << '\n'; dumpBBs());
  BBInfo.clear();
  ImmBranches.clear();
  return MadeChange;
}
/// Return true if MBB's layout successor exists and is one of MBB's CFG
/// successors (i.e. control may fall through into it).
static bool BBHasFallthrough(MachineBasicBlock *MBB) {
  MachineFunction::iterator Next = llvm::next(MachineFunction::iterator(MBB));
  // The last block in layout order has nothing to fall into.
  if (Next == MBB->getParent()->end())
    return false;
  MachineBasicBlock *NextBB = Next;
  return MBB->isSuccessor(NextBB);
}
void AArch64BranchFixup::initializeFunctionInfo() {
BBInfo.clear();
BBInfo.resize(MF->getNumBlockIDs());
for (MachineFunction::iterator I = MF->begin(), E = MF->end(); I != E; ++I)
computeBlockSize(I);
BBInfo.front().KnownBits = MF->getAlignment();
adjustBBOffsetsAfter(MF->begin());
for (MachineFunction::iterator MBBI = MF->begin(), E = MF->end();
MBBI != E; ++MBBI) {
MachineBasicBlock &MBB = *MBBI;
for (MachineBasicBlock::iterator I = MBB.begin(), E = MBB.end();
I != E; ++I) {
if (I->isDebugValue())
continue;
int Opc = I->getOpcode();
if (I->isBranch()) {
bool IsCond = false;
unsigned Bits = 0;
switch (Opc) {
default:
continue; case AArch64::TBZxii:
case AArch64::TBZwii:
case AArch64::TBNZxii:
case AArch64::TBNZwii:
IsCond = true;
Bits = 14 + 2;
break;
case AArch64::Bcc:
case AArch64::CBZx:
case AArch64::CBZw:
case AArch64::CBNZx:
case AArch64::CBNZw:
IsCond = true;
Bits = 19 + 2;
break;
case AArch64::Bimm:
Bits = 26 + 2;
break;
}
ImmBranches.push_back(ImmBranch(I, Bits, IsCond));
}
}
}
}
/// Recompute MBB's BBInfo entry: Size becomes the sum of its instructions'
/// sizes, and Unalign becomes 2 if the block contains inline asm, else 0.
/// NOTE(review): presumably inline asm is treated as only 4-byte aligned
/// because its size is an estimate — confirm.
void AArch64BranchFixup::computeBlockSize(MachineBasicBlock *MBB) {
  unsigned Bytes = 0;
  bool HasInlineAsm = false;
  for (MachineBasicBlock::iterator It = MBB->begin(), End = MBB->end();
       It != End; ++It) {
    Bytes += TII->getInstSizeInBytes(*It);
    HasInlineAsm |= It->isInlineAsm();
  }
  BasicBlockInfo &Info = BBInfo[MBB->getNumber()];
  Info.Size = Bytes;
  Info.Unalign = HasInlineAsm ? 2 : 0;
}
/// Return the estimated byte offset of MI from the start of the function:
/// its block's Offset plus the sizes of all instructions preceding MI in that
/// block.
unsigned AArch64BranchFixup::getOffsetOf(MachineInstr *MI) const {
  MachineBasicBlock *MBB = MI->getParent();
  unsigned Offset = BBInfo[MBB->getNumber()].Offset;
  // Test for the end of the block before dereferencing I, so a missing MI
  // trips the assert instead of dereferencing MBB->end().
  for (MachineBasicBlock::iterator I = MBB->begin();; ++I) {
    assert(I != MBB->end() && "Didn't find MI in its own basic block?");
    if (&*I == MI)
      break;
    Offset += TII->getInstSizeInBytes(*I);
  }
  return Offset;
}
/// Split MI's block immediately before MI. MI and everything after it move
/// into a new block placed right after the original in layout order. The
/// original block ends with an unconditional B to the new block and passes
/// its CFG successors on to it. BBInfo is updated for both blocks and for
/// the offsets of later blocks. Returns the new block.
MachineBasicBlock *
AArch64BranchFixup::splitBlockBeforeInstr(MachineInstr *MI) {
  MachineBasicBlock *Head = MI->getParent();
  MachineBasicBlock *Tail =
    MF->CreateMachineBasicBlock(Head->getBasicBlock());
  MF->insert(llvm::next(MachineFunction::iterator(Head)), Tail);

  // Move MI..end into the tail and join the two halves with a direct branch.
  Tail->splice(Tail->end(), Head, MI, Head->end());
  BuildMI(Head, DebugLoc(), TII->get(AArch64::Bimm)).addMBB(Tail);
  ++NumSplit;

  // The tail inherits every successor of the head; the head now reaches
  // only the tail.
  Tail->transferSuccessors(Head);
  Head->addSuccessor(Tail);

  // Renumber from the tail onward, then open a BBInfo slot at the tail's new
  // number before recomputing sizes and downstream offsets.
  MF->RenumberBlocks(Tail);
  BBInfo.insert(BBInfo.begin() + Tail->getNumber(), BasicBlockInfo());
  computeBlockSize(Head);
  computeBlockSize(Tail);
  adjustBBOffsetsAfter(Head);
  return Tail;
}
/// Recompute Offset and KnownBits for the blocks numbered after BB, each
/// from its predecessor's end and its own alignment. After the first two
/// following blocks have been processed, stops at the first block whose
/// values are already correct.
void AArch64BranchFixup::adjustBBOffsetsAfter(MachineBasicBlock *BB) {
  const unsigned First = BB->getNumber();
  const unsigned NumBlocks = MF->getNumBlockIDs();
  for (unsigned Cur = First + 1; Cur < NumBlocks; ++Cur) {
    const BasicBlockInfo &Prev = BBInfo[Cur - 1];
    BasicBlockInfo &Here = BBInfo[Cur];
    unsigned Align = MF->getBlockNumbered(Cur)->getAlignment();
    unsigned NewOffset = Prev.postOffset(Align);
    unsigned NewKnownBits = Prev.postKnownBits(Align);
    bool PastFirstTwo = Cur > First + 2;
    if (PastFirstTwo && Here.Offset == NewOffset &&
        Here.KnownBits == NewKnownBits)
      break;
    Here.Offset = NewOffset;
    Here.KnownBits = NewKnownBits;
  }
}
/// Return true if the signed byte displacement from MI's estimated address to
/// the start of DestBB fits in OffsetBits bits, i.e. MI can branch to DestBB
/// directly.
bool AArch64BranchFixup::isBBInRange(MachineInstr *MI,
                                     MachineBasicBlock *DestBB,
                                     unsigned OffsetBits) {
  // getOffsetOf walks MI's block instruction by instruction, so compute it
  // once and reuse it for both the debug trace and the range test.
  int64_t BrOffset = getOffsetOf(MI);
  int64_t DestOffset = BBInfo[DestBB->getNumber()].Offset;
  int64_t Displacement = DestOffset - BrOffset;
  DEBUG(dbgs() << "Branch of destination BB#" << DestBB->getNumber()
               << " from BB#" << MI->getParent()->getNumber()
               << " bits available=" << OffsetBits
               << " from " << BrOffset << " to " << DestOffset
               << " offset " << int(Displacement) << "\t" << *MI);
  return isIntN(OffsetBits, Displacement);
}
/// If Br cannot reach its destination block, rewrite it via
/// fixupConditionalBr (only conditional branches are expected to need this).
/// Returns true if the code was changed.
bool AArch64BranchFixup::fixupImmediateBr(ImmBranch &Br) {
  MachineInstr *MI = Br.MI;

  // The destination is the first basic-block operand of the branch.
  MachineBasicBlock *DestBB = 0;
  const unsigned NumOps = MI->getNumOperands();
  unsigned OpIdx = 0;
  while (OpIdx != NumOps && !MI->getOperand(OpIdx).isMBB())
    ++OpIdx;
  if (OpIdx != NumOps)
    DestBB = MI->getOperand(OpIdx).getMBB();
  assert(DestBB && "Branch with no destination BB?");

  if (!isBBInRange(MI, DestBB, Br.OffsetBits)) {
    assert(Br.IsCond && "Only conditional branches should need fixup");
    return fixupConditionalBr(Br);
  }
  return false;
}
/// Rewrite the out-of-range conditional branch Br. The condition is always
/// inverted first; then:
///  - if MI is immediately followed by a final unconditional B whose target
///    the inverted MI can reach, the two branch targets are swapped;
///  - otherwise, if MI is not the last instruction of its block or the block
///    has no layout fallthrough, the block is split right after MI;
///  - finally a B to MI's original destination is appended to MI's block,
///    MI is retargeted to the next block in layout, and the new B is added
///    to ImmBranches.
/// Always returns true.
bool
AArch64BranchFixup::fixupConditionalBr(ImmBranch &Br) {
  MachineInstr *MI = Br.MI;
  MachineBasicBlock *MBB = MI->getParent();
  // Index of MI's basic-block (target) operand.
  unsigned CondBrMBBOperand = 0;
  if (MI->getOpcode() == AArch64::Bcc) {
    // Bcc: operand 0 is the condition code, operand 1 the target. Invert the
    // condition in place.
    CondBrMBBOperand = 1;
    A64CC::CondCodes CC = (A64CC::CondCodes)MI->getOperand(0).getImm();
    CC = A64InvertCondCode(CC);
    MI->getOperand(0).setImm(CC);
  } else {
    // TB(N)Z / CB(N)Z carry their sense in the opcode: build the
    // opposite-sense instruction before MI with the same operands, erase MI,
    // and make the tracked branch point at the replacement.
    MachineInstrBuilder InvertedMI;
    int InvertedOpcode;
    switch (MI->getOpcode()) {
    default: llvm_unreachable("Unknown branch type");
    case AArch64::TBZxii: InvertedOpcode = AArch64::TBNZxii; break;
    case AArch64::TBZwii: InvertedOpcode = AArch64::TBNZwii; break;
    case AArch64::TBNZxii: InvertedOpcode = AArch64::TBZxii; break;
    case AArch64::TBNZwii: InvertedOpcode = AArch64::TBZwii; break;
    case AArch64::CBZx: InvertedOpcode = AArch64::CBNZx; break;
    case AArch64::CBZw: InvertedOpcode = AArch64::CBNZw; break;
    case AArch64::CBNZx: InvertedOpcode = AArch64::CBZx; break;
    case AArch64::CBNZw: InvertedOpcode = AArch64::CBZw; break;
    }
    InvertedMI = BuildMI(*MBB, MI, MI->getDebugLoc(), TII->get(InvertedOpcode));
    for (unsigned i = 0, e= MI->getNumOperands(); i != e; ++i) {
      InvertedMI.addOperand(MI->getOperand(i));
      // Remember where the target operand is.
      if (MI->getOperand(i).isMBB())
        CondBrMBBOperand = i;
    }
    MI->eraseFromParent();
    MI = Br.MI = InvertedMI;
  }
  // Last instruction of the block; a split is needed unless MI is last and
  // the block falls through to its layout successor.
  MachineInstr *BMI = &MBB->back();
  bool NeedSplit = (BMI != MI) || !BBHasFallthrough(MBB);
  ++NumCBrFixed;
  if (BMI != MI) {
    // Special case: MI is second-to-last and the block ends in an
    // unconditional B. If the inverted MI can reach that B's target, swap
    // the two targets; no split or new branch is needed.
    if (llvm::next(MachineBasicBlock::iterator(MI)) == prior(MBB->end()) &&
        BMI->getOpcode() == AArch64::Bimm) {
      MachineBasicBlock *NewDest = BMI->getOperand(0).getMBB();
      if (isBBInRange(MI, NewDest, Br.OffsetBits)) {
        DEBUG(dbgs() << " Invert Bcc condition and swap its destination with "
                     << *BMI);
        MachineBasicBlock *DestBB = MI->getOperand(CondBrMBBOperand).getMBB();
        BMI->getOperand(0).setMBB(DestBB);
        MI->getOperand(CondBrMBBOperand).setMBB(NewDest);
        return true;
      }
    }
  }
  if (NeedSplit) {
    // Split right after MI so that MI ends MBB, then remove the B-to-new-block
    // that splitBlockBeforeInstr appended (a B to the original destination is
    // added below instead) and shrink MBB's recorded size accordingly.
    // NOTE(review): when BMI == MI (MI is last) and there is no fallthrough,
    // MBBI below is MBB->end() and is passed to splitBlockBeforeInstr, which
    // calls getParent() on it. This looks unsafe; confirm the case cannot
    // occur.
    MachineBasicBlock::iterator MBBI = MI; ++MBBI;
    splitBlockBeforeInstr(MBBI);
    int delta = TII->getInstSizeInBytes(MBB->back());
    BBInfo[MBB->getNumber()].Size -= delta;
    MBB->back().eraseFromParent();
  }
  // MI now targets the next block in layout (the fallthrough block or the one
  // just split off); the original destination is reached through a new B.
  MachineBasicBlock *NextBB = llvm::next(MachineFunction::iterator(MBB));
  DEBUG(dbgs() << " Insert B to BB#"
               << MI->getOperand(CondBrMBBOperand).getMBB()->getNumber()
               << " also invert condition and change dest. to BB#"
               << NextBB->getNumber() << "\n");
  BuildMI(MBB, DebugLoc(), TII->get(AArch64::Bimm))
    .addMBB(MI->getOperand(CondBrMBBOperand).getMBB());
  MI->getOperand(CondBrMBBOperand).setMBB(NextBB);
  // Account for the new B in the size/offset estimate and track its range
  // (same width initializeFunctionInfo uses for Bimm).
  BBInfo[MBB->getNumber()].Size += TII->getInstSizeInBytes(MBB->back());
  unsigned OffsetBits = 26 + 2;
  ImmBranches.push_back(ImmBranch(&MBB->back(), OffsetBits, false));
  adjustBBOffsetsAfter(MBB);
  return true;
}