Inefficient loops generated with 2D RDom #7374

Closed
adrian-lebioda opened this issue Feb 24, 2023 · 3 comments · Fixed by #7377

Comments

@adrian-lebioda
Contributor

Hello,

I ran into an issue where the loops generated from a reduction domain were inefficient. Halide correctly detected the stop condition for the loop over the x rdom variable and limited the iteration in that dimension. For the y rdom variable, however, it did not turn the condition into a loop limit, and it also put an if branch inside the inner loop.

I see this in v14.0.0 and also on the release/15.x branch.
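
To make the difference concrete, here is a minimal plain C++ sketch of the two loop shapes. It is not Halide output: `kMaxEntries`, `sum_guarded` and `sum_trimmed` are hypothetical names, and the loop body is a stand-in for the real update. It only shows that folding a `where`-style predicate into the loop extents gives the same result as testing it in the inner loop, without the per-iteration branch.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical stand-in for MAX_ENTRIES from the reproduction case.
constexpr int kMaxEntries = 512;

// Guarded form: both loops run over the full extent and the predicate is
// re-evaluated for every (x, y) pair.
std::uint64_t sum_guarded(int total_entries) {
    std::uint64_t acc = 0;
    for (int y = 0; y < kMaxEntries; ++y) {
        for (int x = 0; x < kMaxEntries; ++x) {
            if (x < total_entries && y < total_entries) {
                acc += static_cast<std::uint64_t>(x + y);
            }
        }
    }
    return acc;
}

// Trimmed form: the predicate is folded into the loop extents once, so the
// loop body needs no branch.
std::uint64_t sum_trimmed(int total_entries) {
    const int extent = std::max(0, std::min(total_entries, kMaxEntries));
    std::uint64_t acc = 0;
    for (int y = 0; y < extent; ++y) {
        for (int x = 0; x < extent; ++x) {
            acc += static_cast<std::uint64_t>(x + y);
        }
    }
    return acc;
}
```

With `total_entries = 10` the guarded form still visits all 512 * 512 pairs, while the trimmed form visits 100.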

I was able to reduce the code to this reproduction case (it could probably be reduced further, since at some point I just complicated things to try to trigger the issue; I think the main point is complicating the totalEntries expression):

#include "Halide.h"

using namespace Halide;

class BugGenerator : public Generator<BugGenerator>
{
public:
    static constexpr int MAX_ENTRIES = 512;

    Input<Buffer<std::uint32_t, 2>> input {"input"};
    Input<Buffer<bool, 3>> valid {"valid"};
    Output<Buffer<std::uint32_t, 2>> result {"result"};

    Func internal {"internal"};
    Var x, y, index;

    void generate()
    {
        const Expr entry1 = input(x / 2, y / 2);
        const Expr entry2 = input(x / 2, y / 2);

        const auto dim0 = input.dim(0);
        const auto dim1 = input.dim(1);
        const Expr entry3 = input(
            clamp(cast<int>(entry1), dim0.min(), dim0.min() + dim0.extent()),
            clamp(cast<int>(entry2), dim1.min(), dim1.min() + dim1.extent()));
        const Expr totalEntries = 1 + 2 * entry1 + entry3;

        Func entries;
        entries(x, y, index) = Tuple(cast<int>(input(x, y) % 10), y, index);

        RDom rdom(0, MAX_ENTRIES, 0, MAX_ENTRIES, "internalRDOM");
        rdom.where(rdom.y < totalEntries);
        rdom.where(rdom.x < totalEntries);

        const Tuple entry = entries(x, y, rdom.y);
        rdom.where(valid(entry.as_vector()));

        internal(x, y, index) = cast<std::uint32_t>(0);

        const auto currentIndex = clamp(totalEntries - rdom.x - 1, 0, MAX_ENTRIES);
        const Expr currentValue = internal(x, y, currentIndex);
        rdom.where(currentValue < 1024);

        internal(x, y, currentIndex) = select(currentValue < 512,
            currentValue + cast<std::uint32_t>(rdom.y),
            currentValue / Expr(2u) + Expr(2u) * cast<std::uint32_t>(rdom.y));

        RDom sumDom(0, MAX_ENTRIES, "sumRDOM");
        sumDom.where(sumDom.x < totalEntries);

        result(x, y) = sum(sumDom, internal(x, y, sumDom.x));
    }

    void schedule()
    {
        internal.compute_root().parallel(y);
        internal.update(0).parallel(y);

        result.parallel(y);
    }
};

HALIDE_REGISTER_GENERATOR(BugGenerator, bug_generator);

and this is the interesting fragment of the lowered stmt (I added comments to mark the problem points):

// Unbounded loop iteration (the extent is the full 512)
for (internal.s1.internalRDOM$y, 0, 512) {
  let internal.s1.internalRDOM$x.new_max.s = let t414 = input[t387] in (let t415 = int32(t414) in (int32(((input[(max(min(t415, t373), input.min.1)*input.stride.1) + (max(min(t415, t372), input.min.0) - t385)] + (t414*(uint32)2)) + (uint32)1)) + -1))
  let t398 = uint32(internal.s1.internalRDOM$y)
  let t394 = max(min(internal.s1.internalRDOM$x.new_max.s, 511), -1)
  let t395 = internal.s1.internalRDOM$y*valid.stride.2
  for (internal.s1.internalRDOM$x, 0, t394 + 1) {
   // both rdom variables are checked in the inner loop (for x this is a second check, as far as I understand)
   if (let t416 = max(internal.s1.internalRDOM$x, internal.s1.internalRDOM$y) in (let t417 = input[t388] in (let t418 = int32(t417) in (let t419 = input[t387] in (let t420 = int32(t419) in (((t416 < int32(((input[(max(min(t418, t373), input.min.1)*input.stride.1) + (max(min(t418, t372), input.min.0) - t385)] + (t417*(uint32)2)) + (uint32)1))) && uint1(valid[(t389 + int32((input[t390] % (uint32)10))) + t395])) && (t416 < int32(((input[(max(min(t420, t373), input.min.1)*input.stride.1) + (max(min(t420, t372), input.min.0) - t385)] + (t419*(uint32)2)) + (uint32)1))))))))) {
    if (let t421 = input[t387] in (let t422 = int32(t421) in (max(internal.s1.internalRDOM$x, internal.s1.internalRDOM$y) < int32(((input[(max(min(t422, t373), input.min.1)*input.stride.1) + (max(min(t422, t372), input.min.0) - t385)] + (t421*(uint32)2)) + (uint32)1))))) {
     if (let t423 = input[t387] in (let t424 = int32(t423) in (internal[(max(min(int32(((input[(max(min(t424, t373), input.min.1)*input.stride.1) + (max(min(t424, t372), input.min.0) - t385)] + (t423*(uint32)2)) + (uint32)1)) - internal.s1.internalRDOM$x, 513), 1)*t383) + t391] < (uint32)1024))) {
      let t332 = input[t387]
      let t333 = int32(t332)
      let t334.s = int32(((input[(max(min(t333, t373), input.min.1)*input.stride.1) + (max(min(t333, t372), input.min.0) - t385)] + (t332*(uint32)2)) + (uint32)1))
      let t337 = internal[(max(min(t334.s - internal.s1.internalRDOM$x, 513), 1)*t383) + t391]
      internal[((max(min(t334.s - internal.s1.internalRDOM$x, 513), 1) + -1)*t383) + t392] = select(t337 < (uint32)512, t337 + t398, (t337/(uint32)2) + (t398*(uint32)2))
     }
    }
   }
  }
 }

In my original case there is a slight difference: my bounds check needs only calculations and no memory access, yet it is emitted as the second condition, while the first condition does a memory access.
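
The cost of that ordering can be sketched in plain C++ under some assumptions: `CountingMask` is a hypothetical wrapper that counts loads from a `valid`-like buffer, and the two functions differ only in which condition is tested first. This is not how Halide orders predicates; it only illustrates why testing a pure arithmetic bound before a memory-dependent condition avoids loads.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical buffer wrapper that counts how many loads are performed.
struct CountingMask {
    std::vector<bool> data;
    mutable std::size_t loads = 0;

    explicit CountingMask(std::vector<bool> d) : data(std::move(d)) {}

    bool load(std::size_t i) const {
        ++loads;
        return data[i];
    }
};

// The memory-dependent condition is tested first: every iteration loads.
std::size_t count_memory_check_first(const CountingMask& valid, int limit) {
    std::size_t hits = 0;
    for (std::size_t i = 0; i < valid.data.size(); ++i) {
        if (valid.load(i) && static_cast<int>(i) < limit) {
            ++hits;
        }
    }
    return hits;
}

// The arithmetic bound is tested first: short-circuiting skips the load
// whenever the index is already out of range.
std::size_t count_bounds_check_first(const CountingMask& valid, int limit) {
    std::size_t hits = 0;
    for (std::size_t i = 0; i < valid.data.size(); ++i) {
        if (static_cast<int>(i) < limit && valid.load(i)) {
            ++hits;
        }
    }
    return hits;
}
```

Both functions return the same count; only the number of loads recorded in `loads` differs.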

I am not sure whether it is possible to force a different statement to be generated, or whether this is a bug or a limitation.

Best regards,
Adrian

@abadams
Member

abadams commented Feb 24, 2023

For unsatisfying implementation reasons it's hard for Halide to do the proof that it's safe to remove the if statements inside the loop, but I did figure out why it's not trimming the loop over y. I'll open a PR.

abadams added a commit that referenced this issue Feb 25, 2023
…7377)

* Bounds visitors for min/max were missing single_point mutated case

Partially fixes #7374

* Add test
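
The commit message is terse, so here is a toy C++ model of what a missing "single point" case can look like. Everything in it is an assumption made for illustration: `Node`, `ToyInterval` and both `bounds_of_min_*` functions are hypothetical and are not Halide's classes or the code changed in #7377. The model treats an interval as a single point only when its lower and upper bound are the very same expression node, so building the two bounds as separate nodes loses that fact even though they are structurally equal.

```cpp
#include <memory>
#include <string>

// Toy expression node; identity (pointer equality) matters, not the text.
struct Node {
    std::string text;
};
using ExprPtr = std::shared_ptr<const Node>;

ExprPtr make_var(const std::string& name) {
    return std::make_shared<const Node>(Node{name});
}

ExprPtr make_min(const ExprPtr& a, const ExprPtr& b) {
    return std::make_shared<const Node>(
        Node{"min(" + a->text + ", " + b->text + ")"});
}

// Toy symbolic interval: a single point iff both bounds are the same node.
struct ToyInterval {
    ExprPtr lo;
    ExprPtr hi;

    bool is_single_point() const { return lo == hi; }

    static ToyInterval single_point(const ExprPtr& e) { return {e, e}; }
};

// Naive rule: bound each end separately. The two results are distinct
// nodes, so the single-point property is lost.
ToyInterval bounds_of_min_naive(const ToyInterval& a, const ToyInterval& b) {
    return {make_min(a.lo, b.lo), make_min(a.hi, b.hi)};
}

// Rule with an explicit single-point case: one shared node for both ends.
ToyInterval bounds_of_min_with_point_case(const ToyInterval& a,
                                          const ToyInterval& b) {
    if (a.is_single_point() && b.is_single_point()) {
        return ToyInterval::single_point(make_min(a.lo, b.lo));
    }
    return bounds_of_min_naive(a, b);
}
```

In this model, any later step that relies on `is_single_point()` to decide whether a bound is exact succeeds only with the second rule.
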
@adrian-lebioda
Contributor Author

Thanks for the quick response :)

I checked the fix from #7377 and it fixes the reproduction case I posted. However, it turns out that the reproduction case was triggering the issue in a different way than my original code.

This time, instead of creating a synthetic case, I basically reused my original code's calculation of totalEntries, and it triggers the issue again:

#include "Halide.h"

using namespace Halide;

class BugGenerator : public Generator<BugGenerator>
{
public:
    static constexpr int MAX_ENTRIES = 512;
    // Start of changes since previous sample
    static constexpr int INPUT_SIZE = 128;
    // End of changes since previous sample

    Input<Buffer<std::uint32_t, 2>> input {"input"};
    Input<Buffer<bool, 3>> valid {"valid"};
    Output<Buffer<std::uint32_t, 2>> result {"result"};

    Func internal {"internal"};
    Var x, y, index;

    void generate()
    {
        // Start of changes since previous sample
        const auto getInput = lambda(x, y, input(clamp(x, 0, INPUT_SIZE), clamp(y, 0, INPUT_SIZE)));

        const Expr size = cast<int>(getInput(x, y) & Expr(0xff00u)) >> Expr(8u);
        const Tuple pos {x, y};
        const Tuple minIndex {
            max(pos[0] - size, 0),
            max(pos[1] - size, 0)};
        const Tuple maxIndex {
            min(pos[0] + size + 1, 0),
            min(pos[1] + size + 1, INPUT_SIZE)};
        const Expr step = max(size / 10, 2);

        const Tuple elements {
            cast<int>(ceil(cast<float>(maxIndex[0] - minIndex[0]) / cast<float>(step))),
            cast<int>(ceil(cast<float>(maxIndex[1] - minIndex[1]) / cast<float>(step)))};
        const Expr totalEntries = 1 + 2 * elements[0] + elements[1];
        // End of changes since previous sample

        Func entries;
        entries(x, y, index) = Tuple(cast<int>(input(x, y) % 10), y, index);

        RDom rdom(0, MAX_ENTRIES, 0, MAX_ENTRIES, "internalRDOM");
        rdom.where(rdom.y < totalEntries);
        rdom.where(rdom.x < rdom.y);

        const Tuple entry = entries(x, y, rdom.y);
        rdom.where(valid(entry.as_vector()));

        internal(x, y, index) = cast<std::uint32_t>(0);

        const auto currentIndex = clamp(totalEntries - rdom.x - 1, 0, MAX_ENTRIES);
        const Expr currentValue = internal(x, y, currentIndex);
        rdom.where(currentValue < 1024);

        internal(x, y, currentIndex) = select(currentValue < 512,
            currentValue + cast<std::uint32_t>(rdom.y),
            currentValue / Expr(2u) + Expr(2u) * cast<std::uint32_t>(rdom.y));

        RDom sumDom(0, MAX_ENTRIES, "sumRDOM");
        sumDom.where(sumDom.x < totalEntries);

        result(x, y) = sum(sumDom, internal(x, y, sumDom.x));
    }

    void schedule()
    {
        // Start of changes since previous sample
        input.dim(0).set_bounds(0, INPUT_SIZE);
        input.dim(1).set_bounds(0, INPUT_SIZE);
        // End of changes since previous sample

        internal.compute_root().parallel(y);
        internal.update(0).parallel(y);

        result.parallel(y);
    }
};

HALIDE_REGISTER_GENERATOR(BugGenerator, bug_generator);

and the generated statement:

 // Again, this loop iteration is unbounded
 for (internal.s1.internalRDOM$y.rebased, 0, 511) {
  if (let t568 = int32((uint32)bitwise_and(input[t543], (uint32)65280)) in (let t569 = float32(max(t568/2560, 2)) in (let t570 = (internal.s1.internalRDOM$y.rebased < ((int32((float32)ceil_f32(float32(((min(((t568/256) + result.min.0) + internal.s1.v0.rebased, -1) + min(((t568/256) - result.min.0) - internal.s1.v0.rebased, 0)) + 1))/t569))*2) + int32((float32)ceil_f32(float32(((min((t568/256) + t542, 127) + min((t568/256) - t542, 0)) + 1))/t569)))) in ((t570 && uint1(valid[((internal.s1.internalRDOM$y.rebased + 1)*valid.stride.2) + (t544 + int32((input[t545] % (uint32)10)))])) && t570)))) {
   if (let t571 = int32((uint32)bitwise_and(input[t543], (uint32)65280)) in (let t572 = float32(max(t571/2560, 2)) in (internal.s1.internalRDOM$y.rebased < ((int32((float32)ceil_f32(float32(((min(((t571/256) + result.min.0) + internal.s1.v0.rebased, -1) + min(((t571/256) - result.min.0) - internal.s1.v0.rebased, 0)) + 1))/t572))*2) + int32((float32)ceil_f32(float32(((min((t571/256) + t542, 127) + min((t571/256) - t542, 0)) + 1))/t572)))))) {
    let t551 = uint32((internal.s1.internalRDOM$y.rebased + 1))
    for (internal.s1.internalRDOM$x, 0, internal.s1.internalRDOM$y.rebased + 1) {
     if (let t573 = int32((uint32)bitwise_and(input[t543], (uint32)65280)) in (let t574 = float32(max(t573/2560, 2)) in (internal[(max(min(((int32((float32)ceil_f32(float32(((min(((t573/256) + result.min.0) + internal.s1.v0.rebased, -1) + min(((t573/256) - result.min.0) - internal.s1.v0.rebased, 0)) + 1))/t574))*2) + int32((float32)ceil_f32(float32(((min((t573/256) + t542, 127) + min((t573/256) - t542, 0)) + 1))/t574))) - internal.s1.internalRDOM$x, 512), 0)*t540) + t546] < (uint32)1024))) {
      let t503 = int32((uint32)bitwise_and(input[t547], (uint32)65280))
      let t504 = float32(max(t503/2560, 2))
      let t505 = (max(min(((int32((float32)ceil_f32(float32(((min(((t503/256) + result.min.0) + internal.s1.v0.rebased, -1) + min(((t503/256) - result.min.0) - internal.s1.v0.rebased, 0)) + 1))/t504))*2) + int32((float32)ceil_f32(float32(((min((t503/256) + t542, 127) + min((t503/256) - t542, 0)) + 1))/t504))) - internal.s1.internalRDOM$x, 512), 0)*t540) + t546
      let t506 = internal[t505]
      internal[t505] = select(t506 < (uint32)512, t506 + t551, (t506/(uint32)2) + (t551*(uint32)2))
     }
    }
   }
  }
 }

@adrian-lebioda
Contributor Author

I found the cause of this one: a case is missing for the Div node. I will try to prepare a PR once I have tested that my implementation is correct.
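
The comment does not say which case is missing, so the following is only a generic illustration of what a bounds rule for division has to provide. It is a plain C++ sketch under stated assumptions: `Range`, `floor_div` and `bounds_of_div` are hypothetical names, the divisor is a positive constant, and division rounds toward negative infinity. When the numerator is known exactly, the quotient is known exactly too; an analysis with no rule for this has to treat the quotient as unknown, and a loop limit that depends on it cannot be derived.

```cpp
#include <cassert>

// Toy numeric interval with inclusive ends.
struct Range {
    int lo;
    int hi;

    bool is_single_point() const { return lo == hi; }
};

// Division rounding toward negative infinity; the divisor must be positive.
int floor_div(int a, int b) {
    assert(b > 0);
    int q = a / b;  // C++ truncates toward zero
    if (a % b != 0 && a < 0) {
        --q;  // adjust negative, inexact quotients downwards
    }
    return q;
}

// For a positive constant divisor, floor division is monotonic, so the
// bounds of the quotient are the quotients of the bounds.
Range bounds_of_div(const Range& a, int divisor) {
    return {floor_div(a.lo, divisor), floor_div(a.hi, divisor)};
}
```

For example, a numerator fixed at 7 divided by 2 gives the single point 3, while a numerator in [-7, 7] gives the range [-4, 3].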

ardier pushed a commit to ardier/Halide-mutation that referenced this issue Mar 3, 2024
…alide#7377)

* Bounds visitors for min/max were missing single_point mutated case

Partially fixes halide#7374

* Add test