Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pairing of function parameters. #5225

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 90 additions & 11 deletions source/diff/diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,59 @@ class Differ {
std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
match_group);

// Bucket `src_ids` and `dst_ids` by the key ids returned by `get_group`, and
// then call `match_group` on pairs of buckets whose key ids are matched with
// each other.
//
// For example, suppose we want to pair up groups of instructions with the
// same type. Naturally, the source instructions refer to their types by their
// ids in the source, and the destination instructions use destination type
// ids, so simply comparing source and destination type ids as integers, as
// `GroupIdsAndMatch` would do, is meaningless. But if a prior call to
// `MatchTypeIds` has established type matches between the two modules, then
// we can consult those to pair source and destination buckets whose types are
// equivalent.
//
// Suppose our input groups are as follows:
//
// - src_ids: { 1 -> 100, 2 -> 300, 3 -> 100, 4 -> 200 }
// - dst_ids: { 5 -> 10, 6 -> 20, 7 -> 10, 8 -> 300 }
//
// Here, `X -> Y` means that the instruction with SPIR-V id `X` is a member of
// the group, and `Y` is the id of its type. If we use
// `Differ::GroupIdsHelperGetTypeId` for `get_group`, then
// `get_group(X) == Y`.
//
// These instructions are bucketed by type as follows:
//
// - source: [1, 3] -> 100
// [4] -> 200
// [2] -> 300
//
// - destination: [5, 7] -> 10
// [6] -> 20
// [8] -> 300
//
// Now suppose that we have previously matched up src type 100 with dst type
// 10, and src type 200 with dst type 20, but no other types are matched.
//
// Then `match_group` is called twice:
// - Once with ([1,3], [5, 7]), corresponding to 100/10
// - Once with ([4],[6]), corresponding to 200/20
//
// The source type 300 isn't matched with anything, so the fact that there's a
// destination type 300 is irrelevant, and thus 2 and 8 are never passed to
// `match_group`.
//
// This function isn't specific to types; it simply buckets by the ids
// returned from `get_group`, and consults existing matches to pair up the
// resulting buckets.
void GroupIdsAndMatchByMappedId(
const IdGroup& src_ids, const IdGroup& dst_ids,
uint32_t (Differ::*get_group)(const IdInstructions&, uint32_t),
std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
match_group);

// Helper functions that determine if two instructions match
bool DoIdsMatch(uint32_t src_id, uint32_t dst_id);
bool DoesOperandMatch(const opt::Operand& src_operand,
Expand Down Expand Up @@ -889,6 +942,37 @@ void Differ::GroupIdsAndMatch(
}
}

void Differ::GroupIdsAndMatchByMappedId(
const IdGroup& src_ids, const IdGroup& dst_ids,
uint32_t (Differ::*get_group)(const IdInstructions&, uint32_t),
std::function<void(const IdGroup& src_group, const IdGroup& dst_group)>
match_group) {
// Group the ids based on a key (get_group)
std::map<uint32_t, IdGroup> src_groups;
std::map<uint32_t, IdGroup> dst_groups;

GroupIds<uint32_t>(src_ids, true, &src_groups, get_group);
GroupIds<uint32_t>(dst_ids, false, &dst_groups, get_group);

// Iterate over pairs of groups whose keys map to each other.
for (const auto& iter : src_groups) {
const uint32_t& src_key = iter.first;
const IdGroup& src_group = iter.second;

if (src_key == 0) {
continue;
}

if (id_map_.IsSrcMapped(src_key)) {
const uint32_t& dst_key = id_map_.MappedDstId(src_key);
const IdGroup& dst_group = dst_groups[dst_key];

// Let the caller match the groups as appropriate.
match_group(src_group, dst_group);
}
}
}

bool Differ::DoIdsMatch(uint32_t src_id, uint32_t dst_id) {
assert(dst_id != 0);
return id_map_.MappedDstId(src_id) == dst_id;
Expand Down Expand Up @@ -1419,7 +1503,6 @@ void Differ::MatchTypeForwardPointersByName(const IdGroup& src,
GroupIdsAndMatch<std::string>(
src, dst, "", &Differ::GetSanitizedName,
[this](const IdGroup& src_group, const IdGroup& dst_group) {

// Match only if there's a unique forward declaration with this debug
// name.
if (src_group.size() == 1 && dst_group.size() == 1) {
Expand Down Expand Up @@ -1574,6 +1657,8 @@ void Differ::BestEffortMatchFunctions(const IdGroup& src_func_ids,

id_map_.MapIds(match_result.src_id, match_result.dst_id);

MatchFunctionParamIds(src_funcs_[match_result.src_id],
dst_funcs_[match_result.dst_id]);
MatchIdsInFunctionBodies(src_func_insts.at(match_result.src_id),
dst_func_insts.at(match_result.dst_id),
match_result.src_match, match_result.dst_match, 0);
Expand All @@ -1598,7 +1683,6 @@ void Differ::MatchFunctionParamIds(const opt::Function* src_func,
GroupIdsAndMatch<std::string>(
src_params, dst_params, "", &Differ::GetSanitizedName,
[this](const IdGroup& src_group, const IdGroup& dst_group) {

// There shouldn't be two parameters with the same name, so the ids
// should match. There is nothing restricting the SPIR-V however to have
// two parameters with the same name, so be resilient against that.
Expand All @@ -1609,17 +1693,17 @@ void Differ::MatchFunctionParamIds(const opt::Function* src_func,

// Then match the parameters by their type. If there are multiple of them,
// match them by their order.
GroupIdsAndMatch<uint32_t>(
src_params, dst_params, 0, &Differ::GroupIdsHelperGetTypeId,
GroupIdsAndMatchByMappedId(
src_params, dst_params, &Differ::GroupIdsHelperGetTypeId,
[this](const IdGroup& src_group_by_type_id,
const IdGroup& dst_group_by_type_id) {

const size_t shared_param_count =
std::min(src_group_by_type_id.size(), dst_group_by_type_id.size());

for (size_t param_index = 0; param_index < shared_param_count;
++param_index) {
id_map_.MapIds(src_group_by_type_id[0], dst_group_by_type_id[0]);
id_map_.MapIds(src_group_by_type_id[param_index],
dst_group_by_type_id[param_index]);
}
});
}
Expand Down Expand Up @@ -2126,15 +2210,13 @@ void Differ::MatchTypeForwardPointers() {
spv::StorageClass::Max, &Differ::GroupIdsHelperGetTypePointerStorageClass,
[this](const IdGroup& src_group_by_storage_class,
const IdGroup& dst_group_by_storage_class) {

// Group them further by the type they are pointing to and loop over
// them.
GroupIdsAndMatch<spv::Op>(
src_group_by_storage_class, dst_group_by_storage_class,
spv::Op::Max, &Differ::GroupIdsHelperGetTypePointerTypeOp,
[this](const IdGroup& src_group_by_type_op,
const IdGroup& dst_group_by_type_op) {

// Group them even further by debug info, if possible and match by
// debug name.
MatchTypeForwardPointersByName(src_group_by_type_op,
Expand Down Expand Up @@ -2378,7 +2460,6 @@ void Differ::MatchFunctions() {
GroupIdsAndMatch<std::string>(
src_func_ids, dst_func_ids, "", &Differ::GetSanitizedName,
[this](const IdGroup& src_group, const IdGroup& dst_group) {

// If there is a single function with this name in src and dst, it's a
// definite match.
if (src_group.size() == 1 && dst_group.size() == 1) {
Expand All @@ -2392,7 +2473,6 @@ void Differ::MatchFunctions() {
&Differ::GroupIdsHelperGetTypeId,
[this](const IdGroup& src_group_by_type_id,
const IdGroup& dst_group_by_type_id) {

if (src_group_by_type_id.size() == 1 &&
dst_group_by_type_id.size() == 1) {
id_map_.MapIds(src_group_by_type_id[0],
Expand Down Expand Up @@ -2437,7 +2517,6 @@ void Differ::MatchFunctions() {
src_func_ids, dst_func_ids, 0, &Differ::GroupIdsHelperGetTypeId,
[this](const IdGroup& src_group_by_type_id,
const IdGroup& dst_group_by_type_id) {

BestEffortMatchFunctions(src_group_by_type_id, dst_group_by_type_id,
src_func_insts_, dst_func_insts_);
});
Expand Down
44 changes: 17 additions & 27 deletions test/diff/diff_files/different_decorations_fragment_autogen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ OpFunctionEnd
; Version: 1.6
; Generator: Khronos SPIR-V Tools Assembler; 0
-; Bound: 82
+; Bound: 92
+; Bound: 89
; Schema: 0
OpCapability Shader
OpMemoryModel Logical GLSL450
Expand Down Expand Up @@ -1030,8 +1030,7 @@ OpFunctionEnd
+OpDecorate %83 DescriptorSet 0
+OpDecorate %83 Binding 0
OpDecorate %32 RelaxedPrecision
-OpDecorate %33 RelaxedPrecision
+OpDecorate %84 RelaxedPrecision
OpDecorate %33 RelaxedPrecision
OpDecorate %36 RelaxedPrecision
OpDecorate %37 RelaxedPrecision
OpDecorate %38 RelaxedPrecision
Expand All @@ -1040,10 +1039,8 @@ OpFunctionEnd
OpDecorate %42 RelaxedPrecision
OpDecorate %43 RelaxedPrecision
OpDecorate %48 RelaxedPrecision
-OpDecorate %49 RelaxedPrecision
-OpDecorate %50 RelaxedPrecision
+OpDecorate %85 RelaxedPrecision
+OpDecorate %86 RelaxedPrecision
OpDecorate %49 RelaxedPrecision
OpDecorate %50 RelaxedPrecision
OpDecorate %52 RelaxedPrecision
OpDecorate %53 RelaxedPrecision
OpDecorate %54 RelaxedPrecision
Expand Down Expand Up @@ -1082,13 +1079,13 @@ OpFunctionEnd
%61 = OpTypeVoid
%69 = OpConstant %16 0
%78 = OpConstant %16 1
+%88 = OpTypePointer Private %2
+%85 = OpTypePointer Private %2
%3 = OpTypePointer Input %2
%7 = OpTypePointer UniformConstant %6
%10 = OpTypePointer UniformConstant %9
%13 = OpTypePointer Uniform %12
%19 = OpTypePointer Uniform %18
+%89 = OpTypePointer Private %2
+%86 = OpTypePointer Private %2
%21 = OpTypePointer Output %2
%28 = OpTypePointer Uniform %27
%30 = OpTypePointer Function %2
Expand All @@ -1106,19 +1103,16 @@ OpFunctionEnd
%22 = OpVariable %21 Output
-%29 = OpVariable %28 Uniform
+%83 = OpVariable %28 Uniform
+%90 = OpConstant %23 0
+%91 = OpConstant %1 0.5
+%87 = OpConstant %23 0
+%88 = OpConstant %1 0.5
%32 = OpFunction %2 None %31
-%33 = OpFunctionParameter %30
+%84 = OpFunctionParameter %30
%33 = OpFunctionParameter %30
%34 = OpLabel
%36 = OpLoad %6 %8
-%37 = OpLoad %2 %33
+%37 = OpLoad %2 %84
%37 = OpLoad %2 %33
%38 = OpVectorShuffle %35 %37 %37 0 1
%39 = OpImageSampleImplicitLod %2 %36 %38
-%41 = OpLoad %2 %33
+%41 = OpLoad %2 %84
%41 = OpLoad %2 %33
%42 = OpVectorShuffle %35 %41 %41 2 3
%43 = OpConvertFToS %40 %42
%44 = OpLoad %9 %11
Expand All @@ -1127,16 +1121,12 @@ OpFunctionEnd
OpReturnValue %46
OpFunctionEnd
%48 = OpFunction %2 None %47
-%49 = OpFunctionParameter %30
-%50 = OpFunctionParameter %30
+%85 = OpFunctionParameter %30
+%86 = OpFunctionParameter %30
%49 = OpFunctionParameter %30
%50 = OpFunctionParameter %30
%51 = OpLabel
-%52 = OpLoad %2 %49
+%52 = OpLoad %2 %85
%52 = OpLoad %2 %49
%53 = OpVectorShuffle %35 %52 %52 0 1
-%54 = OpLoad %2 %50
+%54 = OpLoad %2 %86
%54 = OpLoad %2 %50
%55 = OpVectorShuffle %35 %54 %54 2 3
%56 = OpCompositeExtract %1 %53 0
%57 = OpCompositeExtract %1 %53 1
Expand All @@ -1154,9 +1144,9 @@ OpFunctionEnd
OpStore %65 %66
%67 = OpFunctionCall %2 %32 %65
-%71 = OpAccessChain %70 %14 %69
+%87 = OpAccessChain %70 %82 %69
+%84 = OpAccessChain %70 %82 %69
-%72 = OpLoad %2 %71
+%72 = OpLoad %2 %87
+%72 = OpLoad %2 %84
OpStore %68 %72
-%74 = OpAccessChain %70 %20 %69 %69
+%74 = OpAccessChain %70 %14 %69 %69
Expand Down
Loading