Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremy-lunarg committed Sep 10, 2024
1 parent e721617 commit b88032a
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 24 deletions.
145 changes: 122 additions & 23 deletions source/opt/private_to_local_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace opt {
namespace {
constexpr uint32_t kVariableStorageClassInIdx = 0;
constexpr uint32_t kSpvTypePointerTypeIdInIdx = 1;
constexpr uint32_t kEntryPointFunctionIdInIdx = 1;
} // namespace

Pass::Status PrivateToLocalPass::Process() {
Expand All @@ -48,9 +49,11 @@ Pass::Status PrivateToLocalPass::Process() {
continue;
}

Function* target_function = FindLocalFunction(inst);
if (target_function != nullptr) {
variables_to_move.push_back({&inst, target_function});
// TODO: Handle all functions.
// TODO: Might want to return the map from entry points to functions.
std::set<Function*> target_functions = FindLocalFunctions(inst);
if (!target_functions.empty()) {
variables_to_move.push_back({&inst, *(target_functions.begin())});
}
}

Expand Down Expand Up @@ -85,31 +88,55 @@ Pass::Status PrivateToLocalPass::Process() {
return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange);
}

Function* PrivateToLocalPass::FindLocalFunction(const Instruction& inst) const {
bool found_first_use = false;
Function* target_function = nullptr;
std::set<Function*> PrivateToLocalPass::FindLocalFunctions(const Instruction& inst) const {
// Create a map of entry points to the function id containing the first use of the instruction. There
// must only be one function per entry point if we wish to substitute the private variable.
std::unordered_map<Function*, Function*> ep_to_use {};

auto const result_id = inst.result_id();
context()->get_def_use_mgr()->ForEachUser(
inst.result_id(),
[&target_function, &found_first_use, this](Instruction* use) {
BasicBlock* current_block = context()->get_instr_block(use);
if (current_block == nullptr) {
return;
result_id,
[&ep_to_use, &inst, this](Instruction* use) {
BasicBlock* current_block = context()->get_instr_block(use);
if (current_block == nullptr) {
return;
}

// If use is invalid, then remove all references to the current functions.
Function* current_function = current_block->GetParent();
if (!IsValidUse(use)) {
for (auto iter = std::begin(ep_to_use); iter != std::end(ep_to_use);) {
if (iter->second == current_function) {
iter = ep_to_use.erase(iter);
}
}
return;
}

// Find all entry points that can reach the use instruction.
std::unordered_set<Function*> entry_points;
std::set<Function*> visited_ids;
FindEntryPointFuncs(current_function, entry_points, visited_ids);

if (!IsValidUse(use)) {
found_first_use = true;
target_function = nullptr;
// Update the map of entry points. If the function isn't found, then add it. If the function is found,
// then it must match the current function; otherwise, substitution will not be allowed.
for (auto const entry_point : entry_points) {
auto const ep_iter = ep_to_use.find(entry_point);
if(ep_iter == std::end(ep_to_use)) {
ep_to_use[entry_point] = current_function;
} else if(ep_iter->second != current_function) {
return;
}
Function* current_function = current_block->GetParent();
if (!found_first_use) {
found_first_use = true;
target_function = current_function;
} else if (target_function != current_function) {
target_function = nullptr;
}
});
return target_function;
}
});

// TODO: Copy functions from ep_to_use to return variable.
// Return target functions that can substitute the variable.
std::set<Function*> target_functions {};
//if(target_functions.empty())


return target_functions;
} // namespace opt

bool PrivateToLocalPass::MoveVariable(Instruction* variable,
Expand Down Expand Up @@ -232,5 +259,77 @@ bool PrivateToLocalPass::UpdateUses(Instruction* inst) {
return true;
}

bool PrivateToLocalPass::IsEntryPointFunc(const Function* func) const {
for (auto& entry_point : get_module()->entry_points()) {
if (entry_point.GetSingleWordInOperand(kEntryPointFunctionIdInIdx) ==
func->result_id()) {
return true;
}
}

return false;
}

// TODO: Remove?
Instruction PrivateToLocalPass::GetEntryPointFunc(const Function& func) const {
// if(IsEntryPointFunc(func)) {
// return func.DefInst();
// }

Instruction* ep_func {nullptr};
context()->get_def_use_mgr()->WhileEachUser(func.result_id(),
[&ep_func](Instruction* use) {
switch (use->opcode()) {
case spv::Op::OpFunctionCall:
ep_func = use;
return false;
break;
default:
return true;
break;
};
});

return *ep_func;
}

bool PrivateToLocalPass::IsEntryPointFunc(const uint32_t& func_id) const {
for (auto& entry_point : get_module()->entry_points()) {
if (entry_point.GetSingleWordInOperand(kEntryPointFunctionIdInIdx) == func_id) {
return true;
}
}

return false;
}

// A function may be reached from more than one entry point.
void PrivateToLocalPass::FindEntryPointFuncs(Function* func,
std::unordered_set<Function*>& entry_points,
std::set<Function*>& visited_funcs) const {
// Ignore cycles. Stop if we've visited this function already.
if(visited_funcs.find(func) != std::end(visited_funcs)) {
return;
} else {
visited_funcs.insert(func);
}

if(IsEntryPointFunc(func)) {
entry_points.insert(func);
}

context()->get_def_use_mgr()->ForEachUser(func->result_id(), [this, &entry_points, &visited_funcs](Instruction* use) {
switch (use->opcode()) {
case spv::Op::OpFunctionCall: {
auto current_function = context()->get_instr_block(use)->GetParent();
FindEntryPointFuncs(current_function, entry_points, visited_funcs);
break;
}
default:
break;
};
});
}

} // namespace opt
} // namespace spvtools
11 changes: 10 additions & 1 deletion source/opt/private_to_local_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ class PrivateToLocalPass : public Pass {
// class of |function|. Returns false if the variable could not be moved.
bool MoveVariable(Instruction* variable, Function* function);

// TODO: Update the comment.
// |inst| is an instruction declaring a variable. If that variable is
// referenced in a single function and all of uses are valid as defined by
// |IsValidUse|, then that function is returned. Otherwise, the return
// value is |nullptr|.
Function* FindLocalFunction(const Instruction& inst) const;
std::set<Function*> FindLocalFunctions(const Instruction& inst) const;

// Returns true is |inst| is a valid use of a pointer. In this case, a
// valid use is one where the transformation is able to rewrite the type to
Expand All @@ -65,6 +66,14 @@ class PrivateToLocalPass : public Pass {
// change of the base pointer now pointing to the function storage class.
bool UpdateUse(Instruction* inst, Instruction* user);
bool UpdateUses(Instruction* inst);

bool IsEntryPointFunc(const Function* func) const;
Instruction GetEntryPointFunc(const Function& func) const;

bool IsEntryPointFunc(const uint32_t& func_id) const;
void FindEntryPointFuncs(Function* func,
std::unordered_set<Function*>& entry_points,
std::set<Function*>& visited_funcs) const;
};

} // namespace opt
Expand Down
46 changes: 46 additions & 0 deletions test/opt/private_to_local_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,52 @@ TEST_F(PrivateToLocalTest, DebugPrivateToLocal) {
SinglePassRunAndMatch<PrivateToLocalPass>(text, true);
}

TEST_F(PrivateToLocalTest, TwoEntryPoints) {
const std::string text = R"(
; CHECK-NOT: OpEntryPoint GLCompute %foo "foo" %in %priv1 %priv2
; CHECK: OpEntryPoint GLCompute %foo "foo" %in
; CHECK: %priv1 = OpVariable {{%\w+}} Function
; CHECK: %priv2 = OpVariable {{%\w+}} Function
OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %foo "foo" %in %priv1 %priv2
OpExecutionMode %foo LocalSize 1 1 1
OpName %foo "foo"
OpName %in "in"
OpName %priv1 "priv1"
OpName %priv2 "priv2"
%void = OpTypeVoid
%int = OpTypeInt 32 0
%ptr_ssbo_int = OpTypePointer StorageBuffer %int
%ptr_private_int = OpTypePointer Private %int
%in = OpVariable %ptr_ssbo_int StorageBuffer
%priv1 = OpVariable %ptr_private_int Private
%priv2 = OpVariable %ptr_private_int Private
%void_fn = OpTypeFunction %void
%foo = OpFunction %void None %void_fn
%entry = OpLabel
%1 = OpFunctionCall %void %bar1
%2 = OpFunctionCall %void %bar2
OpReturn
OpFunctionEnd
%bar1 = OpFunction %void None %void_fn
%3 = OpLabel
%ld1 = OpLoad %int %in
OpStore %priv1 %ld1
OpReturn
OpFunctionEnd
%bar2 = OpFunction %void None %void_fn
%4 = OpLabel
%ld2 = OpLoad %int %in
OpStore %priv2 %ld2
OpReturn
OpFunctionEnd
)";

SetTargetEnv(SPV_ENV_UNIVERSAL_1_4);
SinglePassRunAndMatch<PrivateToLocalPass>(text, true);
}

} // namespace
} // namespace opt
} // namespace spvtools

0 comments on commit b88032a

Please sign in to comment.