reg_alloc: Consider bitwidth of data and registers when emitting instructions

This commit is contained in:
MerryMage 2018-01-18 13:00:07 +00:00
parent 144b629d8a
commit fff8e019dc
6 changed files with 193 additions and 79 deletions

View file

@ -39,6 +39,42 @@ static bool IsSameHostLocClass(HostLoc a, HostLoc b) {
|| (HostLocIsSpill(a) && HostLocIsSpill(b));
}
// Minimum number of bits required to represent a type.
//
// Returns 8, 16, 32, 64 or 128 for types that carry a value at runtime.
// U1 is reported as 8 bits, the same width as U8.
// Types that cannot be held at runtime, and Opaque, trip an assertion and
// yield 0.
static size_t GetBitWidth(IR::Type type) {
    switch (type) {
    case IR::Type::A32Reg:
    case IR::Type::A32ExtReg:
    case IR::Type::A64Reg:
    case IR::Type::A64Vec:
    case IR::Type::CoprocInfo:
    case IR::Type::Cond:
    case IR::Type::Void:
        ASSERT_MSG(false, "Type %zu cannot be represented at runtime", static_cast<size_t>(type));
        return 0;
    case IR::Type::Opaque:
        ASSERT_MSG(false, "Not a concrete type");
        return 0;
    case IR::Type::U1:
    case IR::Type::U8:
        return 8;
    case IR::Type::U16:
        return 16;
    case IR::Type::U32:
    case IR::Type::F32:
        return 32;
    case IR::Type::U64:
    case IR::Type::F64:
        return 64;
    case IR::Type::F128:
        return 128;
    case IR::Type::NZCVFlags:
        return 32; // TODO: Update to 16 when flags optimization is done
    }

    // The switch has no default so the compiler can flag unhandled
    // enumerators. A value outside the enumeration would otherwise flow off
    // the end of a non-void function, which is undefined behaviour.
    UNREACHABLE();
    return 0;
}
// Reports whether this host location is currently marked as being used.
bool HostLocInfo::IsLocked() const {
    const bool locked = is_being_used;
    return locked;
}
@ -51,10 +87,6 @@ bool HostLocInfo::IsLastUse() const {
return !is_being_used && current_references == 1 && accumulated_uses + 1 == total_uses;
}
bool HostLocInfo::ContainsValue(const IR::Inst* inst) const {
return std::find(values.begin(), values.end(), inst) != values.end();
}
void HostLocInfo::ReadLock() {
ASSERT(!is_scratch);
is_being_used = true;
@ -66,11 +98,6 @@ void HostLocInfo::WriteLock() {
is_scratch = true;
}
void HostLocInfo::AddValue(IR::Inst* inst) {
values.push_back(inst);
total_uses += inst->UseCount();
}
void HostLocInfo::AddArgReference() {
current_references++;
ASSERT(accumulated_uses + current_references <= total_uses);
@ -84,6 +111,7 @@ void HostLocInfo::EndOfAllocScope() {
values.clear();
accumulated_uses = 0;
total_uses = 0;
max_bit_width = 0;
}
ASSERT(total_uses == std::accumulate(values.begin(), values.end(), size_t(0), [](size_t sum, IR::Inst* inst) { return sum + inst->UseCount(); }));
@ -92,6 +120,20 @@ void HostLocInfo::EndOfAllocScope() {
is_scratch = false;
}
// Returns true when `inst` is one of the values recorded for this location.
bool HostLocInfo::ContainsValue(const IR::Inst* inst) const {
    const auto last = values.end();
    const auto hit = std::find(values.begin(), last, inst);
    return hit != last;
}
// Returns the stored max_bit_width for this location.
size_t HostLocInfo::GetMaxBitWidth() const {
    const size_t widest = max_bit_width;
    return widest;
}
// Records `inst` as living in this location: appends it to `values`, adds its
// use count to `total_uses`, and widens `max_bit_width` if the instruction's
// type needs more bits than anything recorded so far.
void HostLocInfo::AddValue(IR::Inst* inst) {
    values.push_back(inst);
    total_uses += inst->UseCount();

    const size_t inst_bit_width = GetBitWidth(inst->GetType());
    if (max_bit_width < inst_bit_width) {
        max_bit_width = inst_bit_width;
    }
}
// Forwards to the wrapped value and returns its IR type.
IR::Type Argument::GetType() const {
    const IR::Type type = value.GetType();
    return type;
}
@ -439,15 +481,16 @@ HostLoc RegAlloc::LoadImmediate(IR::Value imm, HostLoc host_loc) {
// Relocates the contents of `from` into `to`.
// Preconditions (asserted): `to` is empty, `from` is not locked, and the
// widest value tracked in `from` fits within HostLocBitWidth(to).
// If `from` holds nothing, no instruction is emitted and no state changes.
void RegAlloc::Move(HostLoc to, HostLoc from) {
ASSERT(LocInfo(to).IsEmpty() && !LocInfo(from).IsLocked());
ASSERT(LocInfo(from).GetMaxBitWidth() <= HostLocBitWidth(to));
if (LocInfo(from).IsEmpty()) {
return;
}
// EmitMove reads LocInfo(from).GetMaxBitWidth() to pick the instruction
// width, so it has to run while `from` still carries its bookkeeping.
EmitMove(to, from);
// Transfer the bookkeeping to `to` and reset `from` to a default state.
LocInfo(to) = LocInfo(from);
LocInfo(from) = {};
// NOTE(review): this second EmitMove runs after `from` has been reset; it
// looks like the pre-change line of the diff shown alongside the new one
// above — confirm against the actual file before relying on it.
EmitMove(to, from);
}
void RegAlloc::CopyToScratch(HostLoc to, HostLoc from) {
@ -458,6 +501,8 @@ void RegAlloc::CopyToScratch(HostLoc to, HostLoc from) {
void RegAlloc::Exchange(HostLoc a, HostLoc b) {
ASSERT(!LocInfo(a).IsLocked() && !LocInfo(b).IsLocked());
ASSERT(LocInfo(a).GetMaxBitWidth() <= HostLocBitWidth(b));
ASSERT(LocInfo(b).GetMaxBitWidth() <= HostLocBitWidth(a));
if (LocInfo(a).IsEmpty()) {
Move(a, b);
@ -469,9 +514,9 @@ void RegAlloc::Exchange(HostLoc a, HostLoc b) {
return;
}
std::swap(LocInfo(a), LocInfo(b));
EmitExchange(a, b);
std::swap(LocInfo(a), LocInfo(b));
}
void RegAlloc::MoveOutOfTheWay(HostLoc reg) {
@ -511,22 +556,81 @@ const HostLocInfo& RegAlloc::LocInfo(HostLoc loc) const {
}
void RegAlloc::EmitMove(HostLoc to, HostLoc from) {
const size_t bit_width = LocInfo(from).GetMaxBitWidth();
if (HostLocIsXMM(to) && HostLocIsXMM(from)) {
code->movaps(HostLocToXmm(to), HostLocToXmm(from));
} else if (HostLocIsGPR(to) && HostLocIsGPR(from)) {
code->mov(HostLocToReg64(to), HostLocToReg64(from));
ASSERT(bit_width != 128);
if (bit_width == 64) {
code->mov(HostLocToReg64(to), HostLocToReg64(from));
} else {
code->mov(HostLocToReg64(to).cvt32(), HostLocToReg64(from).cvt32());
}
} else if (HostLocIsXMM(to) && HostLocIsGPR(from)) {
code->movq(HostLocToXmm(to), HostLocToReg64(from));
ASSERT(bit_width != 128);
if (bit_width == 64) {
code->movq(HostLocToXmm(to), HostLocToReg64(from));
} else {
code->movd(HostLocToXmm(to), HostLocToReg64(from).cvt32());
}
} else if (HostLocIsGPR(to) && HostLocIsXMM(from)) {
code->movq(HostLocToReg64(to), HostLocToXmm(from));
ASSERT(bit_width != 128);
if (bit_width == 64) {
code->movq(HostLocToReg64(to), HostLocToXmm(from));
} else {
code->movd(HostLocToReg64(to).cvt32(), HostLocToXmm(from));
}
} else if (HostLocIsXMM(to) && HostLocIsSpill(from)) {
code->movsd(HostLocToXmm(to), spill_to_addr(from));
Xbyak::Address spill_addr = spill_to_addr(from);
ASSERT(spill_addr.getBit() >= bit_width);
switch (bit_width) {
case 128:
code->movaps(HostLocToXmm(to), spill_addr);
break;
case 64:
code->movsd(HostLocToXmm(to), spill_addr);
break;
case 32:
case 16:
case 8:
code->movss(HostLocToXmm(to), spill_addr);
break;
default:
UNREACHABLE();
}
} else if (HostLocIsSpill(to) && HostLocIsXMM(from)) {
code->movsd(spill_to_addr(to), HostLocToXmm(from));
Xbyak::Address spill_addr = spill_to_addr(to);
ASSERT(spill_addr.getBit() >= bit_width);
switch (bit_width) {
case 128:
code->movaps(spill_addr, HostLocToXmm(from));
break;
case 64:
code->movsd(spill_addr, HostLocToXmm(from));
break;
case 32:
case 16:
case 8:
code->movss(spill_addr, HostLocToXmm(from));
break;
default:
UNREACHABLE();
}
} else if (HostLocIsGPR(to) && HostLocIsSpill(from)) {
code->mov(HostLocToReg64(to), spill_to_addr(from));
ASSERT(bit_width != 128);
if (bit_width == 64) {
code->mov(HostLocToReg64(to), spill_to_addr(from));
} else {
code->mov(HostLocToReg64(to).cvt32(), spill_to_addr(from));
}
} else if (HostLocIsSpill(to) && HostLocIsGPR(from)) {
code->mov(spill_to_addr(to), HostLocToReg64(from));
ASSERT(bit_width != 128);
if (bit_width == 64) {
code->mov(spill_to_addr(to), HostLocToReg64(from));
} else {
code->mov(spill_to_addr(to), HostLocToReg64(from).cvt32());
}
} else {
ASSERT_MSG(false, "Invalid RegAlloc::EmitMove");
}