Skip to content

Commit

Permalink
Add Stage::unscheduled()
Browse files Browse the repository at this point in the history
  • Loading branch information
abadams committed Feb 18, 2022
1 parent 0786dd4 commit be1269b
Show file tree
Hide file tree
Showing 13 changed files with 34 additions and 24 deletions.
4 changes: 2 additions & 2 deletions apps/fft/fft.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -872,8 +872,8 @@ ComplexFunc fft2d_r2c(Func r,
dft.update(5).allow_race_conditions().vectorize(n0z2, vector_size);

// Intentionally serial
dft.update(0);
dft.update(3);
dft.update(0).unscheduled();
dft.update(3).unscheduled();

// Our result is undefined outside these bounds.
dft.bound(n0, 0, N0);
Expand Down
2 changes: 1 addition & 1 deletion apps/linear_algebra/src/blas_l1_generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class AXPYGenerator : public Generator<AXPYGenerator<T>> {
Var ii("ii");
result_.update().vectorize(vecs, vec_size);
}
result_.update(1); // Leave the tail unvectorized
result_.update(1).unscheduled(); // Leave the tail unvectorized

result_.bound(i, 0, x_.width());
result_.dim(0).set_bounds(0, x_.width());
Expand Down
9 changes: 6 additions & 3 deletions src/Func.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1928,6 +1928,11 @@ std::string Stage::source_location() const {
return definition.source_location();
}

void Stage::unscheduled() {
user_assert(!definition.schedule().touched()) << "Stage::unscheduled called on an update definition with a schedule\n";
definition.schedule().touched() = true;
}

void Func::invalidate_cache() {
if (pipeline_.defined()) {
pipeline_.invalidate_cache();
Expand Down Expand Up @@ -2760,9 +2765,7 @@ void Func::debug_to_file(const string &filename) {
Stage Func::update(int idx) {
user_assert(idx < num_update_definitions()) << "Call to update with index larger than last defined update stage for Func \"" << name() << "\".\n";
invalidate_cache();
Definition d = func.update(idx);
d.schedule().touched() = true;
return Stage(func, d, idx + 1);
return Stage(func, func.update(idx), idx + 1);
}

Func::operator Stage() const {
Expand Down
6 changes: 6 additions & 0 deletions src/Func.h
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,12 @@ class Stage {
* empty string if no debug symbols were found or the debug
* symbols were not understood. Works on OS X and Linux only. */
std::string source_location() const;

/** Assert that this stage has intentionally been given no schedule, and
* suppress the warning about unscheduled update definitions that would
* otherwise fire. This counts as a schedule, so calling this twice on the
* same Stage will fail the assertion. */
void unscheduled();
};

// For backwards compatibility, keep the ScheduleHandle name.
Expand Down
2 changes: 1 addition & 1 deletion src/ScheduleFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2083,7 +2083,7 @@ bool validate_schedule(Function f, const Stmt &s, const Target &target, bool is_
<< " has not been scheduled, even though some other"
<< " definitions have been. You may have forgotten to"
<< " schedule it. If this was intentional, call "
<< f.name() << ".update(" << i << ") to suppress"
<< f.name() << ".update(" << i << ").unscheduled() to suppress"
<< " this warning.\n";
}
}
Expand Down
3 changes: 3 additions & 0 deletions test/correctness/atomics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,9 @@ void test_predicated_hist(const Backend &backend) {
hist(im(r2)) -= cast<T>(1);
hist(im(r2)) = min(hist(im(r2)) + cast<T>(1), cast<T>(100));

hist.update(3).unscheduled();
hist.update(4).unscheduled();

hist.compute_root();
for (int update_id = 0; update_id < 3; update_id++) {
switch (backend) {
Expand Down
14 changes: 6 additions & 8 deletions test/correctness/compute_with.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,9 @@ int multiple_fuse_group_test() {
p.fuse(x, y, t).parallel(t);
h.fuse(x, y, t).parallel(t);
h.compute_with(p, t);
h.update(0); // unfused
h.update(1); // unfused
h.update(2); // unfused
h.update(0).unscheduled();
h.update(1).unscheduled();
h.update(2).unscheduled();

f.update(0).compute_with(g, y, LoopAlignStrategy::AlignEnd);
f.compute_with(g, x);
Expand Down Expand Up @@ -1280,9 +1280,8 @@ int update_stage_test() {
g.compute_root();
f.compute_root();

f.update(0).unscheduled();
f.update(1).compute_with(g.update(0), y);
f.update(0); // unfused
g.update(1); // unfused

g.bound(x, 0, g_size).bound(y, 0, g_size);
f.bound(x, 0, f_size).bound(y, 0, f_size);
Expand Down Expand Up @@ -1356,7 +1355,6 @@ int update_stage2_test() {

f.update(0).compute_with(g.update(0), y);
f.update(1).compute_with(g.update(0), y);
g.update(1); // unfused

g.bound(x, 0, g_size).bound(y, 0, g_size);
f.bound(x, 0, f_size).bound(y, 0, f_size);
Expand Down Expand Up @@ -1665,8 +1663,8 @@ int update_stage_diagonal_test() {

f.update(1).compute_with(g.update(0), y);
g.update(0).compute_with(h, y);
f.update(0);
g.update(1);
f.update(0).unscheduled();
g.update(1).unscheduled();

g.bound(x, 0, g_size).bound(y, 0, g_size);
f.bound(x, 0, f_size).bound(y, 0, f_size);
Expand Down
2 changes: 1 addition & 1 deletion test/correctness/extern_bounds_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ int main(int argc, char **argv) {
f1.compute_at(g, y);
f2.compute_at(g, x);
g.reorder(y, x).vectorize(y, 4);
g.update();
g.update().unscheduled();

g.infer_input_bounds({W, H});

Expand Down
2 changes: 1 addition & 1 deletion test/correctness/named_updates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ int main(int argc, char **argv) {
more_updates.b.vectorize(r, 4);
more_updates.c.vectorize(r, 4);

f.update(); // fix_first isn't scheduled
f.update().unscheduled(); // fix_first isn't scheduled
}

// Define the same thing without all the weird syntax and without
Expand Down
2 changes: 1 addition & 1 deletion test/correctness/parallel_reductions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ int main(int argc, char **argv) {
sum_rows.compute_root().vectorize(i, 4).parallel(j);
sum_rows.update().parallel(j);
sum_cols.compute_root().vectorize(j, 4);
sum_cols.update();
sum_cols.update().unscheduled();
out.output_buffer().dim(0).set_bounds(0, 256);

Buffer<int> result = out.realize({256});
Expand Down
4 changes: 2 additions & 2 deletions test/correctness/sliding_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ int main(int argc, char **argv) {
f(x, y) = call_count(f(x, y));

f.unroll(y, 2);
f.update(0);
f.update(1);
f.update(0).unscheduled();
f.update(1).unscheduled();

Func g("g");
g(x, y) = f(x, y) + f(x, y - 1) + f(x, y - 2);
Expand Down
6 changes: 3 additions & 3 deletions test/correctness/tuple_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ int main(int argc, char **argv) {
f.hexagon(y).vectorize(x, 32);
}
for (int i = 0; i < 10; i++) {
f.update(i);
f.update(i).unscheduled();
if (i & 1) {
if (target.has_gpu_feature()) {
f.update(i).gpu_tile(x, y, xo, yo, xi, yi, 16, 16);
Expand Down Expand Up @@ -102,7 +102,7 @@ int main(int argc, char **argv) {

// Schedule the even update steps on the gpu
for (int i = 0; i < 10; i++) {
f.update(i);
f.update(i).unscheduled();
if (i & 1) {
if (target.has_gpu_feature()) {
f.update(i).gpu_tile(x, y, xo, yo, xi, yi, 16, 16);
Expand Down Expand Up @@ -144,7 +144,7 @@ int main(int argc, char **argv) {

// Schedule the even update steps on the gpu
for (int i = 0; i < 10; i++) {
f.update(i);
f.update(i).unscheduled();
if ((i & 1) == 0) {
if (target.has_gpu_feature()) {
f.update(i).gpu_tile(x, y, xo, yo, xi, yi, 16, 16);
Expand Down
2 changes: 1 addition & 1 deletion test/correctness/vectorized_initialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ int main(int argc, char **argv) {
f(x) = x;
f(r) = f(r - 1) + f(r + 1);
f.compute_root().vectorize(x, 4);
f.update();
f.update().unscheduled();

g(x) = f(x);
Buffer<int> result = g.realize({4});
Expand Down

0 comments on commit be1269b

Please sign in to comment.