Skip to content

Commit

Permalink
Update the list of fused_pairs and run validate_fused_group for speca…
Browse files Browse the repository at this point in the history
…lization definitions too (#6770)

* Update the list of fused_pairs and run validate_fused_group for specialization definitions too.

Fixes #6763.

* Address review comments

* Add const to auto&
  • Loading branch information
vksnk authored May 18, 2022
1 parent 25a3272 commit 13a5470
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/RealizationOrder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,16 @@ void populate_fused_pairs_list(const string &func, const Definition &def,
func, stage_index, fuse_level.var().name());
if (fuse_level.stage_index() == 0) {
parent.definition().schedule().fused_pairs().push_back(pair);
for (auto &s : parent.definition().specializations()) {
s.definition.schedule().fused_pairs().push_back(pair);
}
} else {
internal_assert(fuse_level.stage_index() > 0);
parent.update(fuse_level.stage_index() - 1).schedule().fused_pairs().push_back(pair);
auto &fuse_stage = parent.update(fuse_level.stage_index() - 1);
fuse_stage.schedule().fused_pairs().push_back(pair);
for (auto &s : fuse_stage.specializations()) {
s.definition.schedule().fused_pairs().push_back(pair);
}
}
}

Expand Down
11 changes: 10 additions & 1 deletion src/ScheduleFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2369,9 +2369,18 @@ void validate_fused_groups_schedule(const vector<vector<string>> &fused_groups,

validate_fused_group_schedule_helper(
iter->first, 0, iter->second.definition(), env);
for (const auto &s : iter->second.definition().specializations()) {
validate_fused_group_schedule_helper(
iter->first, 0, s.definition, env);
}
for (size_t i = 0; i < iter->second.updates().size(); ++i) {
const auto &update_stage = iter->second.updates()[i];
validate_fused_group_schedule_helper(
iter->first, i + 1, iter->second.updates()[i], env);
iter->first, i + 1, update_stage, env);
for (const auto &s : update_stage.specializations()) {
validate_fused_group_schedule_helper(
iter->first, i + 1, s.definition, env);
}
}
}
}
Expand Down
1 change: 1 addition & 0 deletions test/error/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ tests(GROUPS error
clamp_out_of_range.cpp
compute_with_crossing_edges1.cpp
compute_with_crossing_edges2.cpp
compute_with_fuse_in_specialization.cpp
constrain_wrong_output_buffer.cpp
constraint_uses_non_param.cpp
define_after_realize.cpp
Expand Down
22 changes: 22 additions & 0 deletions test/error/compute_with_fuse_in_specialization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "Halide.h"
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {
Var x("x"), y("y"), f("f");
ImageParam in(Int(16), 2, "in");
Func out0("out0"), out1("out1");
out0(x, y) = 1 * in(x, y);
out1(x, y) = 2 * in(x, y);

out0.vectorize(x, 8, TailStrategy::RoundUp);
out1.vectorize(x, 8, TailStrategy::RoundUp).compute_with(out0, x);

out0.specialize(in.dim(1).stride() == 128).fuse(x, y, f);
Pipeline p({out0, out1});
p.compile_jit();

printf("Success!\n");
return 0;
}

0 comments on commit 13a5470

Please sign in to comment.