Skip to content

Commit

Permalink
fix SSA pass. Keep Identity (still useful). Split into 2 ReplaceUses …
Browse files Browse the repository at this point in the history
…variants
  • Loading branch information
baggins183 committed Oct 22, 2024
1 parent 2b9db86 commit 4773b19
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 32 deletions.
25 changes: 13 additions & 12 deletions src/shader_recompiler/ir/microinstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,21 +177,22 @@ void Inst::ClearArgs() {
}
}

void Inst::ReplaceUsesWith(Value replacement) {
// move uses because SetArg will call UndoUse and would otherwise
// mutate uses while iterating
#ifdef _DEBUG
void Inst::ReplaceUsesWith(Value replacement, bool preserve) {
// Could also do temp_uses = std::move(uses)
// But clearer this way
// Copy since user->SetArg will mutate this->uses
boost::container::list<IR::Use> temp_uses = uses;
#else
boost::container::list<IR::Use> temp_uses = std::move(uses);
#endif
if (!replacement.IsImmediate()) {
for (auto& [user, operand] : temp_uses) {
DEBUG_ASSERT(user->Arg(operand).Inst() == this);
user->SetArg(operand, replacement);
}
for (auto& [user, operand] : temp_uses) {
DEBUG_ASSERT(user->Arg(operand).Inst() == this);
user->SetArg(operand, replacement);
}
Invalidate();
if (preserve) {
// Still useful to have Identity for indirection.
// SSA pass would be more complicated without it
ReplaceOpcode(Opcode::Identity);
SetArg(0, replacement);
}
}

void Inst::ReplaceOpcode(IR::Opcode opcode) {
Expand Down
35 changes: 18 additions & 17 deletions src/shader_recompiler/ir/passes/constant_propagation_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) {

if (is_lhs_immediate && is_rhs_immediate) {
const auto result{imm_fn(Arg<T>(lhs), Arg<T>(rhs))};
inst.ReplaceUsesWith(IR::Value{result});
inst.ReplaceUsesWithAndRemove(IR::Value{result});
return false;
}
if (is_lhs_immediate && !is_rhs_immediate) {
Expand Down Expand Up @@ -75,20 +75,20 @@ bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
return false;
}
using Indices = std::make_index_sequence<Common::LambdaTraits<decltype(func)>::NUM_ARGS>;
inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{}));
inst.ReplaceUsesWithAndRemove(EvalImmediates(inst, func, Indices{}));
return true;
}

template <IR::Opcode op, typename Dest, typename Source>
void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) {
const IR::Value value{inst.Arg(0)};
if (value.IsImmediate()) {
inst.ReplaceUsesWith(IR::Value{std::bit_cast<Dest>(Arg<Source>(value))});
inst.ReplaceUsesWithAndRemove(IR::Value{std::bit_cast<Dest>(Arg<Source>(value))});
return;
}
IR::Inst* const arg_inst{value.InstRecursive()};
if (arg_inst->GetOpcode() == reverse) {
inst.ReplaceUsesWith(arg_inst->Arg(0));
inst.ReplaceUsesWithAndRemove(arg_inst->Arg(0));
return;
}
}
Expand Down Expand Up @@ -131,7 +131,7 @@ void FoldCompositeExtract(IR::Inst& inst, IR::Opcode construct, IR::Opcode inser
if (!result) {
return;
}
inst.ReplaceUsesWith(*result);
inst.ReplaceUsesWithAndRemove(*result);
}

void FoldConvert(IR::Inst& inst, IR::Opcode opposite) {
Expand All @@ -141,7 +141,7 @@ void FoldConvert(IR::Inst& inst, IR::Opcode opposite) {
}
IR::Inst* const producer{value.InstRecursive()};
if (producer->GetOpcode() == opposite) {
inst.ReplaceUsesWith(producer->Arg(0));
inst.ReplaceUsesWithAndRemove(producer->Arg(0));
}
}

Expand All @@ -152,17 +152,17 @@ void FoldLogicalAnd(IR::Inst& inst) {
const IR::Value rhs{inst.Arg(1)};
if (rhs.IsImmediate()) {
if (rhs.U1()) {
inst.ReplaceUsesWith(inst.Arg(0));
inst.ReplaceUsesWithAndRemove(inst.Arg(0));
} else {
inst.ReplaceUsesWith(IR::Value{false});
inst.ReplaceUsesWithAndRemove(IR::Value{false});
}
}
}

void FoldSelect(IR::Inst& inst) {
const IR::Value cond{inst.Arg(0)};
if (cond.IsImmediate()) {
inst.ReplaceUsesWith(cond.U1() ? inst.Arg(1) : inst.Arg(2));
inst.ReplaceUsesWithAndRemove(cond.U1() ? inst.Arg(1) : inst.Arg(2));
}
}

Expand All @@ -173,22 +173,22 @@ void FoldLogicalOr(IR::Inst& inst) {
const IR::Value rhs{inst.Arg(1)};
if (rhs.IsImmediate()) {
if (rhs.U1()) {
inst.ReplaceUsesWith(IR::Value{true});
inst.ReplaceUsesWithAndRemove(IR::Value{true});
} else {
inst.ReplaceUsesWith(inst.Arg(0));
inst.ReplaceUsesWithAndRemove(inst.Arg(0));
}
}
}

void FoldLogicalNot(IR::Inst& inst) {
const IR::U1 value{inst.Arg(0)};
if (value.IsImmediate()) {
inst.ReplaceUsesWith(IR::Value{!value.U1()});
inst.ReplaceUsesWithAndRemove(IR::Value{!value.U1()});
return;
}
IR::Inst* const arg{value.InstRecursive()};
if (arg->GetOpcode() == IR::Opcode::LogicalNot) {
inst.ReplaceUsesWith(arg->Arg(0));
inst.ReplaceUsesWithAndRemove(arg->Arg(0));
}
}

Expand All @@ -199,7 +199,7 @@ void FoldInverseFunc(IR::Inst& inst, IR::Opcode reverse) {
}
IR::Inst* const arg_inst{value.InstRecursive()};
if (arg_inst->GetOpcode() == reverse) {
inst.ReplaceUsesWith(arg_inst->Arg(0));
inst.ReplaceUsesWithAndRemove(arg_inst->Arg(0));
return;
}
}
Expand All @@ -211,7 +211,7 @@ void FoldAdd(IR::Block& block, IR::Inst& inst) {
}
const IR::Value rhs{inst.Arg(1)};
if (rhs.IsImmediate() && Arg<T>(rhs) == 0) {
inst.ReplaceUsesWith(inst.Arg(0));
inst.ReplaceUsesWithAndRemove(inst.Arg(0));
return;
}
}
Expand All @@ -226,7 +226,8 @@ void FoldCmpClass(IR::Block& block, IR::Inst& inst) {
} else if ((class_mask & IR::FloatClassFunc::Finite) == IR::FloatClassFunc::Finite) {
IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
const IR::F32 value = IR::F32{inst.Arg(0)};
inst.ReplaceUsesWith(ir.LogicalNot(ir.LogicalOr(ir.FPIsInf(value), ir.FPIsInf(value))));
inst.ReplaceUsesWithAndRemove(
ir.LogicalNot(ir.LogicalOr(ir.FPIsInf(value), ir.FPIsInf(value))));
} else {
UNREACHABLE();
}
Expand All @@ -237,7 +238,7 @@ void FoldReadLane(IR::Inst& inst) {
IR::Inst* prod = inst.Arg(0).InstRecursive();
while (prod->GetOpcode() == IR::Opcode::WriteLane) {
if (prod->Arg(2).U32() == lane) {
inst.ReplaceUsesWith(prod->Arg(1));
inst.ReplaceUsesWithAndRemove(prod->Arg(1));
return;
}
prod = prod->Arg(0).InstRecursive();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ void LowerSharedMemToRegisters(IR::Program& program) {
});
ASSERT(it != ds_writes.end());
// Replace data read with value written.
inst.ReplaceUsesWith((*it)->Arg(1));
inst.ReplaceUsesWithAndRemove((*it)->Arg(1));
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/shader_recompiler/ir/passes/resource_tracking_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ void PatchImageSampleInstruction(IR::Block& block, IR::Inst& inst, Info& info,
}
return ir.ImageSampleImplicitLod(handle, coords, bias, offset, inst_info);
}();
inst.ReplaceUsesWith(new_inst);
inst.ReplaceUsesWithAndRemove(new_inst);
}

void PatchImageInstruction(IR::Block& block, IR::Inst& inst, Info& info, Descriptors& descriptors) {
Expand Down
9 changes: 8 additions & 1 deletion src/shader_recompiler/ir/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,13 @@ class Inst : public boost::intrusive::list_base_hook<> {
void Invalidate();
void ClearArgs();

void ReplaceUsesWith(Value replacement);
void ReplaceUsesWithAndRemove(Value replacement) {
ReplaceUsesWith(replacement, false);
}

void ReplaceUsesWith(Value replacement) {
ReplaceUsesWith(replacement, true);
}

void ReplaceOpcode(IR::Opcode opcode);

Expand Down Expand Up @@ -212,6 +218,7 @@ class Inst : public boost::intrusive::list_base_hook<> {

void Use(Inst* used, u32 operand);
void UndoUse(Inst* used, u32 operand);
void ReplaceUsesWith(Value replacement, bool preserve);

IR::Opcode op{};
u32 flags{};
Expand Down
1 change: 1 addition & 0 deletions src/shader_recompiler/recompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ IR::Program TranslateProgram(std::span<const u32> code, Pools& pools, Info& info

// Run optimization passes
Shader::Optimization::SsaRewritePass(program.post_order_blocks);
Shader::Optimization::IdentityRemovalPass(program.blocks);
Shader::Optimization::ConstantPropagationPass(program.post_order_blocks);
if (program.info.stage != Stage::Compute) {
Shader::Optimization::LowerSharedMemToRegisters(program);
Expand Down

0 comments on commit 4773b19

Please sign in to comment.