#include "AMDGPU.h"
#include "AMDGPUSubtarget.h"
#include "SIInstrInfo.h"
#include "SIMachineFunctionInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
using namespace llvm;
namespace {

/// Values for the three S_WAITCNT counters. The same storage is reachable by
/// name (Named.VM / Named.EXP / Named.LGKM) or by index (Array[0] / [1] / [2],
/// in that order) so that code can loop over all counters.
/// NOTE(review): the code writes through one member and reads through the
/// other; ISO C++ leaves that union type punning undefined, so this relies on
/// compiler support.
typedef union {
  struct {
    unsigned VM;
    unsigned EXP;
    unsigned LGKM;
  } Named;
  unsigned Array[3];
} Counters;

/// One Counters snapshot per register, indexed by the register's hardware
/// encoding value (see SIInsertWaits::getRegInterval).
using RegCounters = Counters[512];

/// Half-open range [first, second) of register encoding values.
using RegInterval = std::pair<unsigned, unsigned>;

/// Machine pass ("SI insert wait instructions") that inserts S_WAITCNT
/// instructions in front of instructions whose register operands depend on
/// still-outstanding counted operations, and at the end of each block.
class SIInsertWaits : public MachineFunctionPass {
private:
  static char ID;

  const SIInstrInfo *TII;
  const SIRegisterInfo *TRI;
  const MachineRegisterInfo *MRI;

  /// Largest value each field can hold in the S_WAITCNT immediate; insertWait
  /// uses it for counters that need no wait.
  static const Counters WaitCounts;
  /// All counters zero.
  static const Counters ZeroCounts;

  /// Counter positions already covered by an emitted S_WAITCNT.
  Counters WaitedOn;
  /// Running totals of counter increments issued so far.
  Counters LastIssued;
  /// LastIssued snapshot taken when a register was last read (UsedRegs) or
  /// written (DefinedRegs) by an instruction that increments some counter.
  RegCounters UsedRegs;
  RegCounters DefinedRegs;

  /// Bitmask of EXP_CNT instruction kinds issued since the last S_WAITCNT with
  /// an EXP count of 0: value 1 = EXP, value 2 = a may-store instruction.
  unsigned ExpInstrTypesSeen;

  /// Per-counter increment caused by issuing \p MI.
  Counters getHwCounts(MachineInstr &MI);

  /// Whether \p Op should be recorded in UsedRegs/DefinedRegs.
  bool isOpRelevant(MachineOperand &Op);

  /// Register encoding range covered by \p Op; empty for non-tracked operands.
  RegInterval getRegInterval(MachineOperand &Op);

  /// Accounts for \p MI having been issued.
  void pushInstruction(MachineInstr &MI);

  /// Emits an S_WAITCNT before \p I if \p Counts is not yet satisfied.
  bool insertWait(MachineBasicBlock &MBB,
                  MachineBasicBlock::iterator I,
                  const Counters &Counts);

  // NOTE(review): declared but not defined in this file.
  bool unorderedDefines(MachineInstr &MI);

  /// Counter positions \p MI must wait for because of its register operands.
  Counters handleOperands(MachineInstr &MI);

public:
  // Single-argument constructor made explicit to forbid implicit conversion
  // from TargetMachine; MRI is now initialized like the other pointers.
  explicit SIInsertWaits(TargetMachine &tm) :
    MachineFunctionPass(ID),
    TII(nullptr),
    TRI(nullptr),
    MRI(nullptr),
    ExpInstrTypesSeen(0) { }

  bool runOnMachineFunction(MachineFunction &MF) override;

  const char *getPassName() const override {
    return "SI insert wait instructions";
  }
};

}
// Pass identifier handed to the MachineFunctionPass base constructor.
char SIInsertWaits::ID{0};

// Maximum value of each S_WAITCNT field (VM fits 4 bits, EXP and LGKM 3 bits,
// matching the masks applied in insertWait); insertWait starts from these
// values so counters that need no wait keep them.
const Counters SIInsertWaits::WaitCounts{{15, 7, 7}};

// Every counter at zero; used to reset per-function state.
const Counters SIInsertWaits::ZeroCounts{{0, 0, 0}};
/// Factory for the SIInsertWaits pass. Returns a freshly heap-allocated pass
/// object; the caller takes ownership of it.
FunctionPass *llvm::createSIInsertWaits(TargetMachine &tm) {
  SIInsertWaits *Pass = new SIInsertWaits(tm);
  return Pass;
}
/// Returns how much issuing \p MI increments each hardware counter.
///
/// - VM:   1 when the opcode's TSFlags contain SIInstrFlags::VM_CNT.
/// - EXP:  1 when TSFlags contain SIInstrFlags::EXP_CNT and the instruction is
///         either an EXP or may store.
/// - LGKM: 0 without SIInstrFlags::LGKM_CNT; otherwise 1, except for an SMRD
///         whose first operand's register class is wider than 4 bytes, which
///         counts as 2.
Counters SIInsertWaits::getHwCounts(MachineInstr &MI) {
  const uint64_t Flags = TII->get(MI.getOpcode()).TSFlags;

  // Start from all-zero and raise only the counters this instruction touches.
  Counters Increment = ZeroCounts;

  if (Flags & SIInstrFlags::VM_CNT)
    Increment.Named.VM = 1;

  if ((Flags & SIInstrFlags::EXP_CNT) &&
      (MI.getOpcode() == AMDGPU::EXP || MI.getDesc().mayStore()))
    Increment.Named.EXP = 1;

  if (Flags & SIInstrFlags::LGKM_CNT) {
    Increment.Named.LGKM = 1;
    if (TII->isSMRD(MI.getOpcode())) {
      const MachineOperand &Dst = MI.getOperand(0);
      assert(Dst.isReg() && "First LGKM operand must be a register!");
      const unsigned Bytes =
          TRI->getMinimalPhysRegClass(Dst.getReg())->getSize();
      if (Bytes > 4)
        Increment.Named.LGKM = 2;
    }
  }

  return Increment;
}
/// Decides whether \p Op is recorded in the register tables by pushInstruction.
///
/// Register definitions always qualify. For an EXP every register operand
/// qualifies. For other instructions that may store, a use qualifies only if
/// it is identical (MachineOperand::isIdenticalTo) to the instruction's first
/// register-use operand. Everything else is ignored.
bool SIInsertWaits::isOpRelevant(MachineOperand &Op) {
  if (!Op.isReg())
    return false;
  if (Op.isDef())
    return true;

  const MachineInstr &Owner = *Op.getParent();
  if (Owner.getOpcode() == AMDGPU::EXP)
    return true;
  if (!Owner.getDesc().mayStore())
    return false;

  // Advance to the first operand that is a register use.
  MachineInstr::const_mop_iterator It = Owner.operands_begin();
  const MachineInstr::const_mop_iterator End = Owner.operands_end();
  while (It != End && !(It->isReg() && It->isUse()))
    ++It;

  return It != End && Op.isIdenticalTo(*It);
}
/// Maps \p Op to the half-open range of register encoding values it covers,
/// used as indices into UsedRegs/DefinedRegs.
///
/// The range starts at the register's encoding value and spans one slot per
/// 4 bytes of its minimal physical register class. Non-register operands and
/// registers not in an allocatable class yield the empty range [0, 0).
RegInterval SIInsertWaits::getRegInterval(MachineOperand &Op) {
  if (Op.isReg() && TRI->isInAllocatableClass(Op.getReg())) {
    const unsigned Reg = Op.getReg();
    const unsigned Bytes = TRI->getMinimalPhysRegClass(Reg)->getSize();
    assert(Bytes >= 4);
    const unsigned First = TRI->getEncodingValue(Reg);
    return RegInterval(First, First + Bytes / 4);
  }
  return RegInterval(0, 0);
}
/// Accounts for \p MI having been issued.
///
/// Adds MI's counter increments to LastIssued. If MI increments nothing, the
/// function stops there. Otherwise it records which kind of EXP_CNT
/// instruction was seen (1 = EXP, 2 = other), and stamps every relevant
/// register operand's slots with the new LastIssued: definitions go to
/// DefinedRegs, uses to UsedRegs.
void SIInsertWaits::pushInstruction(MachineInstr &MI) {
  const Counters Increment = getHwCounts(MI);

  bool TouchesAnyCounter = false;
  for (unsigned C = 0; C != 3; ++C) {
    LastIssued.Array[C] += Increment.Array[C];
    TouchesAnyCounter |= Increment.Array[C] != 0;
  }
  if (!TouchesAnyCounter)
    return;

  if (Increment.Named.EXP)
    ExpInstrTypesSeen |= (MI.getOpcode() == AMDGPU::EXP) ? 1 : 2;

  for (MachineInstr::mop_iterator It = MI.operands_begin(),
                                  End = MI.operands_end();
       It != End; ++It) {
    MachineOperand &Op = *It;
    if (!isOpRelevant(Op))
      continue;

    // A relevant operand is a register, so it is exactly one of def or use.
    RegCounters &Table = Op.isDef() ? DefinedRegs : UsedRegs;
    const RegInterval Range = getRegInterval(Op);
    for (unsigned Slot = Range.first; Slot < Range.second; ++Slot)
      Table[Slot] = LastIssued;
  }
}
/// Emits an S_WAITCNT before \p I when some counter in \p Required has not yet
/// been covered by an earlier wait.
///
/// \param MBB      block that receives the wait.
/// \param I        insertion point; the wait is placed before it. Nothing is
///                 emitted in front of S_ENDPGM.
/// \param Required per-counter positions (snapshots of LastIssued) that the
///                 program must have waited past.
/// \returns true if an S_WAITCNT was inserted.
bool SIInsertWaits::insertWait(MachineBasicBlock &MBB,
                               MachineBasicBlock::iterator I,
                               const Counters &Required) {
  if (I != MBB.end() && I->getOpcode() == AMDGPU::S_ENDPGM)
    return false;

  // Ordered[i]: counter i is treated as decrementing in issue order. Only then
  // may the wait leave newer events outstanding. Unordered counters are
  // waited down to 0.
  bool Ordered[3];
  Ordered[0] = true;   // VM_CNT
  // EXP_CNT is shared by EXP instructions (ExpInstrTypesSeen value 1) and
  // may-store instructions (value 2). With only one kind outstanding the
  // count is ordered. When both kinds are mixed they can complete out of
  // order, so we must wait for 0. The original `== 3` had this inverted: it
  // emitted an unsafe partial wait exactly when the kinds were mixed.
  Ordered[1] = ExpInstrTypesSeen != 3;
  Ordered[2] = false;  // LGKM_CNT: always waited to 0.

  // Counters that need no wait keep their maximum encodable value.
  Counters Counts = WaitCounts;
  bool NeedWait = false;

  for (unsigned i = 0; i < 3; ++i) {
    if (Required.Array[i] <= WaitedOn.Array[i])
      continue;

    NeedWait = true;
    if (Ordered[i]) {
      // Number of events issued after the required one; these may remain
      // outstanding. Clamp to what the encoding can express.
      unsigned Value = LastIssued.Array[i] - Required.Array[i];
      Counts.Array[i] = std::min(Value, WaitCounts.Array[i]);
    } else {
      Counts.Array[i] = 0;
    }
    WaitedOn.Array[i] = LastIssued.Array[i] - Counts.Array[i];
  }

  if (!NeedWait)
    return false;

  // Waiting EXP_CNT down to 0 leaves no EXP-counted instruction outstanding.
  if (Counts.Named.EXP == 0)
    ExpInstrTypesSeen = 0;

  // Immediate layout: VM in bits [3:0], EXP in [6:4], LGKM in [10:8].
  BuildMI(MBB, I, DebugLoc(), TII->get(AMDGPU::S_WAITCNT))
      .addImm((Counts.Named.VM & 0xF) |
              ((Counts.Named.EXP & 0x7) << 4) |
              ((Counts.Named.LGKM & 0x7) << 8));

  return true;
}
/// Raises each counter in \p Dst to at least the corresponding value in
/// \p Src (element-wise maximum, stored into \p Dst).
static void increaseCounters(Counters &Dst, const Counters &Src) {
  for (unsigned Idx = 0; Idx != 3; ++Idx) {
    if (Src.Array[Idx] > Dst.Array[Idx])
      Dst.Array[Idx] = Src.Array[Idx];
  }
}
/// Computes the per-counter positions \p MI has to wait for because of its
/// register operands.
///
/// Every register slot an operand covers contributes its DefinedRegs
/// snapshot. Definitions additionally contribute the UsedRegs snapshot. The
/// result is the element-wise maximum. S_SENDMSG ignores its operands and
/// requires everything issued so far (LastIssued).
Counters SIInsertWaits::handleOperands(MachineInstr &MI) {
  if (MI.getOpcode() == AMDGPU::S_SENDMSG)
    return LastIssued;

  Counters Needed = ZeroCounts;

  for (MachineInstr::mop_iterator It = MI.operands_begin(),
                                  End = MI.operands_end();
       It != End; ++It) {
    MachineOperand &Op = *It;
    const RegInterval Range = getRegInterval(Op);

    // A non-empty range implies a register operand, which is either a def or
    // a use. Both must wait for earlier writes; defs also wait for readers.
    for (unsigned Slot = Range.first; Slot < Range.second; ++Slot) {
      increaseCounters(Needed, DefinedRegs[Slot]);
      if (Op.isDef())
        increaseCounters(Needed, UsedRegs[Slot]);
    }
  }

  return Needed;
}
/// Walks every instruction of \p MF in block order.
///
/// Before each instruction, inserts the wait its register operands require,
/// then records the instruction as issued. At the end of each block, inserts
/// a wait for everything issued so far in front of the first terminator.
///
/// \returns true if any S_WAITCNT was inserted.
bool SIInsertWaits::runOnMachineFunction(MachineFunction &MF) {
  bool Changes = false;

  TII = static_cast<const SIInstrInfo *>(MF.getSubtarget().getInstrInfo());
  TRI =
      static_cast<const SIRegisterInfo *>(MF.getSubtarget().getRegisterInfo());
  MRI = &MF.getRegInfo();

  // All tracking state lives in members that persist across calls, so reset
  // it for each function. ExpInstrTypesSeen used to be missed here. It
  // survived into the next function whenever the last block ended in S_ENDPGM,
  // because insertWait skips emitting a wait in front of S_ENDPGM and so never
  // clears it.
  WaitedOn = ZeroCounts;
  LastIssued = ZeroCounts;
  ExpInstrTypesSeen = 0;
  memset(&UsedRegs, 0, sizeof(UsedRegs));
  memset(&DefinedRegs, 0, sizeof(DefinedRegs));

  for (MachineBasicBlock &MBB : MF) {
    for (MachineBasicBlock::iterator I = MBB.begin(), E = MBB.end();
         I != E; ++I) {
      Changes |= insertWait(MBB, I, handleOperands(*I));
      pushInstruction(*I);
    }

    // Wait for everything still outstanding before leaving the block.
    Changes |= insertWait(MBB, MBB.getFirstTerminator(), LastIssued);
  }

  return Changes;
}