Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes #10172

Merged
merged 2 commits into from
Feb 22, 2022

Conversation

lazycal
Copy link
Contributor

@lazycal lazycal commented Feb 4, 2022

Thanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers by @ them in the pull request thread.

The following model

import tvm
from tvm import relay
import numpy as np

xshape = (1, 1, 1)
inp = np.random.uniform(size=xshape).astype(np.int64)

x = relay.var("x", shape=xshape, dtype='int64')
x = relay.cast(x, 'int64')
x = relay.broadcast_to(x, relay.const([1, 2, 2], dtype='int64'))
func = relay.Function(relay.analysis.free_vars(x), -x)
mod = tvm.IRModule.from_expr(func)

with tvm.transform.PassContext(opt_level=0):
    relay.create_executor("debug", mod, tvm.cpu()).evaluate()(inp)

triggers two issues regarding base and stride dtype mismatch in Ramp, one in VectorizeLoop Pass and the other in NarrowDataType Pass. Error message looks like Check failed: stride.dtype() == base.dtype() (int32 vs. int64) :.

The fix

  • During VectorizeLoop, a loop variable will be converted to a ramp but always of dtype int32. This PR changes it to use the loop variable's dtype.
  • During NarrowDataType, it can happen that the stride is inferred with int32, but base is not (see the added test case for detail). This PR adds an upcasting when rewriting a Ramp node that has base and stride inferred with different number of bits.

@lazycal
Copy link
Contributor Author

lazycal commented Feb 4, 2022

Not sure whom I should request for reviews, but it seems simimlar to this PR #9582. So ccing the reviewers there @YuchenJin @junrushao1994 @Mousius

@lazycal lazycal changed the title [TIR] Fix Ramp dtype mismatch in VectorizeLoop and NarrowDataType passes [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes Feb 4, 2022
@tqchen
Copy link
Member

tqchen commented Feb 5, 2022

also cc @vinx13 @hzfan @yzhliu pelease take a look when you have time

Copy link
Contributor

@ganler ganler left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM and thanks for the contribution! But I am wondering if we can simply perform such compatible casting when constructing Ramp node? (https://github.com/apache/tvm/tree/main/src/tir/ir/expr.cc#L705)

We should add ICHECK(base.is_int() && stride.is_int()) if the two will only be integers.

Comment on lines +262 to +263
if (base.dtype().is_int()) {
ICHECK(stride.dtype().is_int()) << "Ramp base is int but stride is " << stride.dtype();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can simply assume that base and stride should be of integer types. However, I also noticed that in

Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) {

Such assumptions are not checked. I am a bit curious if there will be, say base/stride in float types?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if they can be floats. So I added that conservative line of code that only acts on integers. If they can only be integer we should add that ICHECK you mention.

@ganler
Copy link
Contributor

ganler commented Feb 7, 2022

@lazycal Using this impl (based on yours) in https://github.com/lazycal/tvm/blob/ffe6649855c4c247f4bb85c9d48c5ca157850a1d/src/tir/ir/expr.cc#L705 fixes the bug you mentioned and might be more general to overcome other hidden ones if we can assume that base and stride must be of integers.

Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) {
  ICHECK(base.defined());
  ICHECK(stride.defined());
  ICHECK(base.dtype().is_scalar());
  ICHECK(stride.dtype().is_scalar());
  ICHECK_GT(lanes, 1);
  ICHECK(base.dtype().is_int());
  ICHECK(stride.dtype().is_int());
  
  if (base.dtype() != stride.dtype()) {
    size_t bits = std::max(base.dtype().bits(), stride.dtype().bits());
    DataType dtype = base.dtype().with_bits(bits);
    if (base.dtype() != dtype) base = cast(dtype, base);
    if (stride.dtype() != dtype) stride = cast(dtype, stride);
  }

  ObjectPtr<RampNode> node = make_object<RampNode>();
  node->dtype = base.dtype().with_lanes(lanes);
  node->base = base;
  node->stride = stride;
  node->lanes = lanes;
  node->span = std::move(span);
  data_ = std::move(node);
}

@junrushao
Copy link
Member

This is definitely an interesting bug! CC: @vinx13 @yzhliu @hzfan would be great if you guys could take a look :-)

@lazycal
Copy link
Contributor Author

lazycal commented Feb 7, 2022

@lazycal Using this impl (based on yours) in https://github.com/lazycal/tvm/blob/ffe6649855c4c247f4bb85c9d48c5ca157850a1d/src/tir/ir/expr.cc#L705 fixes the bug you mentioned and might be more general to overcome other hidden ones if we can assume that base and stride must be of integers.

Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) {
  ICHECK(base.defined());
  ICHECK(stride.defined());
  ICHECK(base.dtype().is_scalar());
  ICHECK(stride.dtype().is_scalar());
  ICHECK_GT(lanes, 1);
  ICHECK(base.dtype().is_int());
  ICHECK(stride.dtype().is_int());
  
  if (base.dtype() != stride.dtype()) {
    size_t bits = std::max(base.dtype().bits(), stride.dtype().bits());
    DataType dtype = base.dtype().with_bits(bits);
    if (base.dtype() != dtype) base = cast(dtype, base);
    if (stride.dtype() != dtype) stride = cast(dtype, stride);
  }

  ObjectPtr<RampNode> node = make_object<RampNode>();
  node->dtype = base.dtype().with_lanes(lanes);
  node->base = base;
  node->stride = stride;
  node->lanes = lanes;
  node->span = std::move(span);
  data_ = std::move(node);
}

I didn't do what you said because

  • I want to make as little change as possible. Changing a constructor of a fundamental class might be more likely to break things IMO.
  • Also I think the author wrote this class that way must for a reason. Indeed I guess one of them could be to catch such corner cases in pass rewrite algorithms. You are right that implicitly upcasting in the constructor fixes both the bugs, but I am not sure if this is always the desired behavior. For example, there might be some other pass that should have other special handling other than upcasting, and if it forgets to do so, it'd be good to catch that, but implicit upcasting will have it silently ignored or exposed at a much later time.

@ganler
Copy link
Contributor

ganler commented Feb 7, 2022

@lazycal Fair consideration!

I also tried ./tests/scripts/task_python_unittest.sh for the direct fix. It seems to pass all related tests (except a few ones due to my environment), which means at least the unit-tests use Ramp with base and stride in int and is good with casting for int64 and int32 during Ramp node construction.

@masahi
Copy link
Member

masahi commented Feb 21, 2022

I've just hit a similar error, when compiling an int8 model with tensorized ops (VNNI):

  21: tvm::te::MakeTensorize(tvm::te::ComputeOpNode const*, tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, bool)
  20: tvm::te::VerifyTensorizeBody(tvm::te::ComputeOpNode const*, tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::PrimExpr, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::PrimExpr> > > const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::te::Tensor, tvm::runtime::Array<tvm::Range, void>, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::runtime::Array<tvm::Range, void> > > > const&, tvm::te::TensorIntrin const&)
  19: tvm::te::MatchTensorizeBody(tvm::te::ComputeOpNode const*, tvm::te::Stage const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::tir::IterVar, tvm::Range, std::hash<tvm::tir::IterVar>, std::equal_to<tvm::tir::IterVar>, std::allocator<std::pair<tvm::tir::IterVar const, tvm::Range> > > const&, std::unordered_map<tvm::te::Tensor, tvm::runtime::Array<tvm::Range, void>, std::hash<tvm::te::Tensor>, std::equal_to<tvm::te::Tensor>, std::allocator<std::pair<tvm::te::Tensor const, tvm::runtime::Array<tvm::Range, void> > > > const&, tvm::te::TensorIntrin const&, tvm::runtime::Map<tvm::tir::Var, tvm::Range, void, void>*)
  18: non-virtual thunk to tvm::tir::StmtExprMutator::VisitExpr(tvm::PrimExpr const&)
  17: _ZZN3tvm3tir11ExprFunctorIFNS_8PrimExprERKS2_EE10InitVTableEvENUlRKNS_7runtime
  16: tvm::te::TensorIntrinMatcher::VisitExpr_(tvm::tir::ReduceNode const*)

  ...

  0: tvm::tir::FloorDiv::FloorDiv(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
  File "/home/masa/projects/dev/tvm/src/tir/ir/expr.cc", line 322
TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32

I wonder if this is related.

@masahi masahi self-assigned this Feb 21, 2022
@masahi masahi merged commit d8e39fd into apache:main Feb 22, 2022
pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
…passes (apache#10172)

[TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants