[TIR] Bugs in HoistIfThenElse #5559

Closed
roastduck opened this issue May 11, 2020 · 7 comments

roastduck commented May 11, 2020

HoistIfThenElse is a pass currently not enabled in TVM. I tried to enable it in #5553, but there are too many bugs in this pass. Let's fix them first.

BUG 1: HoistIfThenElse transforms

for (n.inner, 0, 2) {
  for (o.inner, 0, 2) {
    if ((((threadIdx.y*2) + n.inner) < 2)) {
      if ((((threadIdx.z*2) + o.inner) < 4)) {
        if ((threadIdx.y < 1)) {
          if ((threadIdx.z < 2)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}

into the following, where the condition on n.inner now sits outside the loop that defines n.inner:

if ((((threadIdx.y*2) + n.inner) < 2)) {
  if ((threadIdx.y < 1)) {
    if ((threadIdx.z < 2)) {
      for (n.inner, 0, 2) {
        for (o.inner, 0, 2) {
          if ((((threadIdx.z*2) + o.inner) < 4)) {
            tvm_store_matrix_sync(Conv.wmma.accumulator, 16, 16, 16, ((n.inner*2) + o.inner), tvm_access_ptr(type_annotation(), Conv, (((((threadIdx.y*401408) + (n.inner*200704)) + (blockIdx.z*1024)) + (threadIdx.z*512)) + (o.inner*256)), 256, 2), 16, "row_major")
          }
        }
      }
    }
  }
}

Possible cause:

https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/src/tir/pass/hoist_if_then_else.cc#L295

The check only tests whether if_stmt has a preferred position, but that position is not guaranteed to be the loop currently being processed. Changing it to

if (if_position_map.count(if_stmt.get()) &&
    if_position_map.at(if_stmt.get()).as<ForNode>()->loop_var.get() == top_for_var) {

may solve the problem.
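
To make the symptom concrete, here is a minimal sketch of a well-formedness check that flags a condition reading a loop variable outside the loop that binds it. It is not TVM code: the `Store`, `If` and `For` classes and the `unbound_uses` helper are hypothetical stand-ins for the TIR nodes, and conditions are reduced to the set of variables they read.

```python
from dataclasses import dataclass
from typing import FrozenSet, Union


# Hypothetical miniature IR, standing in for TIR nodes (assumption, not TVM's API).
@dataclass(frozen=True)
class Store:
    name: str


@dataclass(frozen=True)
class If:
    uses: FrozenSet[str]  # variables read by the condition
    body: "Stmt"


@dataclass(frozen=True)
class For:
    var: str
    body: "Stmt"


Stmt = Union[Store, If, For]


def unbound_uses(stmt, bound=frozenset()):
    """Return variables read by some condition outside the scope that binds them."""
    if isinstance(stmt, For):
        return unbound_uses(stmt.body, bound | {stmt.var})
    if isinstance(stmt, If):
        return (stmt.uses - bound) | unbound_uses(stmt.body, bound)
    return frozenset()


# Thread indices are treated as bound everywhere in this example.
THREADS = frozenset({"threadIdx.y", "threadIdx.z"})
store = Store("tvm_store_matrix_sync")

# The IR before hoisting (BUG 1, first listing).
before = For("n.inner", For("o.inner",
    If(frozenset({"threadIdx.y", "n.inner"}),
       If(frozenset({"threadIdx.z", "o.inner"}),
          If(frozenset({"threadIdx.y"}),
             If(frozenset({"threadIdx.z"}), store))))))

# The IR produced by the buggy pass (BUG 1, second listing).
after = If(frozenset({"threadIdx.y", "n.inner"}),
    If(frozenset({"threadIdx.y"}),
       If(frozenset({"threadIdx.z"}),
          For("n.inner", For("o.inner",
              If(frozenset({"threadIdx.z", "o.inner"}), store))))))

print(sorted(unbound_uses(before, THREADS)))
print(sorted(unbound_uses(after, THREADS)))
```

A check of this kind could run after the pass in a unit test, so that an illegal hoist is reported directly instead of surfacing later in code generation.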

BUG 2: src/tir/transforms/split_host_device.cc expects the IR to be in SSA form, where each variable is defined only once. Since we copy loops into both the "then" branch and the "else" branch, we have to rename the loop variables in the "else" branch so that they differ from those in the "then" branch. I have already written some code for this, see #5553.
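
The renaming idea can be sketched as follows. This is only an illustration under simplifying assumptions, not the code from #5553: variables are plain names, a loop body is a tuple of token tuples, nested loops inside the body are not handled, and `Loop`, `fresh_name`, `rename_loop` and `duplicate_for_branches` are hypothetical helpers.

```python
import itertools
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Loop:
    var: str                            # loop variable name
    body: Tuple[Tuple[str, ...], ...]   # statements, each a tuple of tokens


def fresh_name(base, taken):
    """Return a name derived from `base` that does not appear in `taken`."""
    for i in itertools.count(1):
        candidate = f"{base}_{i}"
        if candidate not in taken:
            return candidate


def rename_loop(loop, new_var):
    """Copy `loop`, replacing every use of its loop variable with `new_var`."""
    body = tuple(
        tuple(new_var if tok == loop.var else tok for tok in stmt)
        for stmt in loop.body
    )
    return Loop(new_var, body)


def duplicate_for_branches(loop, taken):
    """Return (then_copy, else_copy) such that the two copies define distinct variables."""
    else_var = fresh_name(loop.var, taken)
    return loop, rename_loop(loop, else_var)


loop = Loop("k", (("data", "[", "k", "]", "=", "0"),))
then_loop, else_loop = duplicate_for_branches(loop, {"i", "j", "k"})
print(then_loop.var, else_loop.var)
```

After duplication each loop variable is defined exactly once, which is the property the SSA requirement asks for.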

BUG 3: IfThenElse nodes whose conditions contain thread indices should not be hoisted above the definition of those indices. This can happen when the Attr node for thread_extent is scheduled into the body of a For node by a compute_at command. I have already written some code for this, see #5553.
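
One way to express this rule is to treat the thread_extent scope like any other scope that defines a variable, and stop hoisting at the innermost scope that defines something the condition reads. The sketch below assumes a simplified model in which the enclosing scopes are given as a list of (kind, variable) pairs, outermost first; `hoist_depth` is a hypothetical helper, not a TVM function.

```python
from typing import FrozenSet, List, Tuple

# (kind, defined variable), e.g. ("for", "i") or ("thread_extent", "threadIdx.x").
Scope = Tuple[str, str]


def hoist_depth(scopes: List[Scope], cond_vars: FrozenSet[str]) -> int:
    """Return how many enclosing scopes must stay around the condition.

    `scopes` is ordered outermost first. A result of 0 means the condition
    can be hoisted above every scope; a result of len(scopes) means it
    cannot move at all.
    """
    # Walk from the innermost scope outward and stop at the first scope
    # that defines a variable the condition reads.
    for idx in range(len(scopes) - 1, -1, -1):
        _kind, var = scopes[idx]
        if var in cond_vars:
            return idx + 1
    return 0


# thread_extent has been scheduled inside the loop over i via compute_at.
scopes = [("for", "i"), ("thread_extent", "threadIdx.x"), ("for", "j")]

print(hoist_depth(scopes, frozenset({"threadIdx.x"})))  # stays inside thread_extent
print(hoist_depth(scopes, frozenset({"j"})))            # cannot leave the j loop
print(hoist_depth(scopes, frozenset()))                 # free to move to the top
```

A version that only looked at "for" scopes would return 0 for the threadIdx.x condition and hoist it above the definition of the thread index, which is the behavior described above.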

BUG 4:

https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/src/tir/pass/hoist_if_then_else.cc#L371

By the time this line runs, if_stmt may already have been updated. Consider the example below.

for (i, 0, 10) {
  for (j, 0, 10) {
    for (k, 0, 10) {
      if ((i >= 3)) {
        if ((j >= 3)) {
          data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
        }
      }
    }
  }
}

After hoisting j >= 3, it becomes

for (i, 0, 10) {
  for (j, 0, 10) {
    if ((j >= 3)) {
      for (k, 0, 10) {
        if ((i >= 3)) {
          data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
        }
      }
    }
  }
}

Now, when we are hoisting i >= 3, we need to compare and remove

if ((i >= 3)) {
  if ((j >= 3)) {
    data[(((i*100) + (j*10)) + k)] = (data[(((i*100) + (j*10)) + k)] + 0.5f)
  }
}

But j >= 3 is already gone, so RemoveIf fails. We have to track updates to IfThenElse nodes, just as we do for For nodes.
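
A rough sketch of such tracking, assuming nodes are identified by object identity and that replacements never form a cycle. `Node` and `UpdateTracker` are hypothetical names used only for illustration; the pass itself works on raw pointers to TVM objects.

```python
class Node:
    """Toy statement node: a label plus an optional body."""

    def __init__(self, label, body=None):
        self.label = label
        self.body = body


class UpdateTracker:
    """Remember which node replaced which, so stale references can be resolved."""

    def __init__(self):
        # id(old) -> (old, new); holding `old` keeps it alive so its id is not reused.
        self._latest = {}

    def record(self, old, new):
        if old is not new:
            self._latest[id(old)] = (old, new)

    def resolve(self, node):
        # Follow the chain of replacements to the most recent version.
        while id(node) in self._latest:
            node = self._latest[id(node)][1]
        return node


store = Node("store")
inner_if = Node("if j >= 3", store)
outer_if = Node("if i >= 3", inner_if)

# Hoisting j >= 3 rebuilds the outer if without the nested condition.
new_outer = Node("if i >= 3", store)
tracker = UpdateTracker()
tracker.record(outer_if, new_outer)

# Later, when hoisting i >= 3, resolve the stale reference before comparing.
current = tracker.resolve(outer_if)
print(current.label, "->", current.body.label)
```

With the stale reference resolved first, the comparison is made against the statement as it exists after the earlier hoist rather than against the original nest.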

BUG 5: This one is in the tests.

https://github.com/apache/incubator-tvm/blob/0e877521f454e239f5c44bb88e557801444d81a5/tests/python/unittest/test_tir_pass_hoist_if.py#L175

Why do we expect a ('For', 'j') inside itself? As a related problem, maybe we should change the variable names so that there are not two variables named i and two named j.

These are all the bugs I found.

Besides, I suggest changing every for (size_t i = 0; i < xxx.size(); i++) into for (size_t i = 0, n = xxx.size(); i < n; i++), since the C++ compiler cannot always prove that xxx.size() is loop invariant.

@kevinthesun Maybe you can have a look.

tqchen commented May 15, 2020

@kevinthesun it would be great if you could follow up.

kevinthesun commented

@roastduck Thank you for bringing these up. This pass was tested only on a limited number of CUDA conv2d workloads and is not production ready yet. It would be great if you could help fix or improve it.

tqchen commented May 29, 2020

@roastduck would you be interested in taking over the pass?

roastduck commented

I ran into some difficulties improving this pass. For now, I'm not going to take it over.

This pass relies heavily on low-level mechanisms such as PostOrderVisit (instead of StmtExprMutator) and raw pointers to Object, and it tracks the updates to these pointers manually, which makes it hard to understand. Maybe we should develop an improved StmtExprMutator that can track the updates to the nodes.
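
For comparison, here is a minimal sketch of how the hoisting in BUG 4 looks when written as a bottom-up rewrite that returns new nodes instead of patching pointers. It uses a hypothetical toy IR (`Cond`, `Store`, `If`, `For`), handles only if statements without an else branch, and only hoists conditions that lead the loop body; it is meant to show the style, not to replace the pass.

```python
from dataclasses import dataclass
from typing import FrozenSet, Union


@dataclass(frozen=True)
class Cond:
    text: str
    uses: FrozenSet[str]  # variables read by the condition


@dataclass(frozen=True)
class Store:
    text: str


@dataclass(frozen=True)
class If:
    cond: Cond
    body: "Stmt"


@dataclass(frozen=True)
class For:
    var: str
    body: "Stmt"


Stmt = Union[Store, If, For]


def sink_loop(var, body):
    """Wrap `body` in a loop over `var`, placed below every leading
    condition that does not read `var`."""
    if isinstance(body, If) and var not in body.cond.uses:
        return If(body.cond, sink_loop(var, body.body))
    return For(var, body)


def hoist(stmt):
    """Rewrite bottom-up; every step returns a fresh tree, so nothing goes stale."""
    if isinstance(stmt, Store):
        return stmt
    if isinstance(stmt, If):
        return If(stmt.cond, hoist(stmt.body))
    return sink_loop(stmt.var, hoist(stmt.body))


i_ge_3 = Cond("i >= 3", frozenset({"i"}))
j_ge_3 = Cond("j >= 3", frozenset({"j"}))
update = Store("data[...] = data[...] + 0.5f")

# The loop nest from BUG 4.
before = For("i", For("j", For("k", If(i_ge_3, If(j_ge_3, update)))))
after = hoist(before)
print(after)
```

Because each call works on the already-rewritten child, there is no separate bookkeeping step for nodes that changed earlier in the pass.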

tqchen commented May 29, 2020

We could certainly rewrite the pass completely, instead of using PostOrderVisit.

tqchen commented Jun 15, 2020

Given that this pass is not production ready and has not yet been migrated to the new transform API, perhaps we can remove it for now and add it back once we have a better implementation. I'll leave the thread open for a week to see what everyone thinks.

tqchen added a commit to tqchen/tvm that referenced this issue Jun 27, 2020
This pass has not been migrated to the new transform API,
and contains potential bugs per apache#5559.
Given that it is not being actively used, this PR removes this pass
from the collection.

Follow-up PRs are more than welcome to land a better version that
conforms with the new transform API.

tqchen commented Jun 27, 2020

#5944 removes this pass for now.

tqchen added a commit that referenced this issue Jun 27, 2020
@tqchen tqchen closed this as completed Jun 27, 2020
trevor-m pushed a commit to trevor-m/tvm that referenced this issue Jun 30, 2020
zhiics pushed a commit to neo-ai/tvm that referenced this issue Jul 2, 2020