Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Enable HoistIfThenElse in the default lowering procedure #5553

Closed
wants to merge 5 commits into from

Conversation

roastduck
Copy link
Contributor

Enabling the HoistIfThenElse pass (#3865) in the default lowering procedure. HoistIfThenElse can be very helpful for sparse applications, since LoopPartition cannot eliminate their if statements with dynamic (unknown at compile time) conditions.

Changes:

  • Move HoistIfThenElse from src/tir/pass to src/tir/transform.
  • Added it into lower.
  • Fixed a breaking test.

@tqchen Requesting for a review.

@roastduck
Copy link
Contributor Author

Well, a test for TensorCore fails.

HoistIfThenElse transforms

for (n.inner, 0, 2) {
  for (o.inner, 0, 2) {
    if ((((threadIdx.y*2) + n.inner) < 2)) {
      if ((((threadIdx.z*2) + o.inner) < 4)) {
        if ((threadIdx.y < 1)) {
          if ((threadIdx.z < 2)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}

into

if ((((threadIdx.y*2) + n.inner) < 2)) {
  if ((threadIdx.y < 1)) {
    if ((threadIdx.z < 2)) {
      for (n.inner, 0, 2) {
        for (o.inner, 0, 2) {
          if ((((threadIdx.z*2) + o.inner) < 4)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}

where the if containing n.inner is wrongly hoisted, seeming like a bug of HoistIfThenElse. I will try to dig it out.

@roastduck
Copy link
Contributor Author

Fixed and added an extra test.

@roastduck roastduck changed the title Enable HoistIfThenElse in the default lowering procedure [TIR] Enable HoistIfThenElse in the default lowering procedure May 10, 2020
@roastduck
Copy link
Contributor Author

Renamed Vars in hoisted else_cases, to be different from those in then_cases, so that SSA verification won't panic.

@roastduck
Copy link
Contributor Author

Fixed a bug where if nodes containing thread indices could be hoisted over the definition of the indices. This would happen when Attr node for thread_extent is scheduled into the body of a For node, using a compute_at command.

@roastduck
Copy link
Contributor Author

Since there are too many bugs in HoistIfThenElse, let's track them in another issue, and enable the pass after they are all fixed. Closing for now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant