[TE] Promote substituted variable to iter_var's dtype #10571

Merged: 5 commits, Mar 12, 2022


tkonolige
Contributor

This fixes a bug where an iteration variable and its associated loop variable have a mismatched dtype.

@mbrookhart @junrushao1994

@tkonolige
Contributor Author

This problem is a lot more complicated than I initially thought. It is actually two problems:

1. IterVar has an invariant that the dtype of its var and the dtype of its domain's extent (dom->extent) match. I fixed all violations of this invariant and added a check for it in the constructor of IterVar.
2. The IterVar of thread/block and the IterVar it binds to need to have the same dtype when computing loop nesting. However, there is no way to know what datatype is needed when constructing the thread/block IterVar, so it defaults to int32. I've added an automatic promotion from thread/block IterVar dtype to the bound IterVar's dtype when computing loop nests, but this is a hack and will cause integer overflow if the bound IterVar exceeds int32::max (should not happen unless we get a GPU/CPU with more than int32::max threads).
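The two fixes described above can be sketched in a few lines of Python. This is a toy model under stated assumptions, not TVM's API: `Var`, `Range`, `IterVar`, `INT_BITS` and `promote_to_bound` are hypothetical names, and dtypes are plain strings. It only illustrates the constructor-time invariant and the widening of a thread/block IterVar to the dtype of the IterVar it is bound to.

```python
from dataclasses import dataclass

# Toy stand-ins for Var, Range and IterVar (hypothetical, not the real TVM API).
INT_BITS = {"int32": 32, "int64": 64}


@dataclass(frozen=True)
class Var:
    name: str
    dtype: str


@dataclass(frozen=True)
class Range:
    min_value: int
    extent: int
    dtype: str  # dtype of the extent expression


@dataclass(frozen=True)
class IterVar:
    dom: Range
    var: Var

    def __post_init__(self):
        # Problem 1: reject an IterVar whose var and dom extent disagree on dtype.
        if self.var.dtype != self.dom.dtype:
            raise TypeError(
                f"IterVar {self.var.name}: var is {self.var.dtype}, "
                f"extent is {self.dom.dtype}"
            )


def promote_to_bound(thread_iv: IterVar, bound_iv: IterVar) -> IterVar:
    """Problem 2: widen a thread/block IterVar to the bound IterVar's dtype."""
    target = bound_iv.var.dtype
    if INT_BITS[target] < INT_BITS[thread_iv.var.dtype]:
        raise TypeError("this sketch only widens, it never narrows")
    return IterVar(
        Range(thread_iv.dom.min_value, thread_iv.dom.extent, target),
        Var(thread_iv.var.name, target),
    )


# The thread IterVar defaults to int32; the loop it binds to is int64.
thread_x = IterVar(Range(0, 128, "int32"), Var("threadIdx.x", "int32"))
i = IterVar(Range(0, 128, "int64"), Var("i", "int64"))
promoted = promote_to_bound(thread_x, i)
print(promoted.var.dtype)  # prints int64
```

The sketch does not model the overflow caveat: widening the dtype says nothing about whether the actual thread index can represent an extent larger than int32::max.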

@masahi
Member

masahi commented Mar 11, 2022

Interesting, does this change remove the need for fixes in the following PRs?

#10519
#9582
#10172

I've been trying to compile the QAT BERT model from MLPerf in various ways, but I consistently get int32 vs. int64 dtype mismatches like:

  2: tvm::arith::SolveInequalitiesToRange(tvm::arith::IntConstraints const&)
  1: tvm::arith::AsConditions(tvm::runtime::Array<tvm::tir::Var, void> const&, tvm::runtime::Map<tvm::tir::Var, tvm::arith::IntGroupBounds, void, void> const&, tvm::runtime::Array<tvm::PrimExpr, void> const&)
  0: tvm::tir::LE::LE(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
  File "/home/masa/projects/dev/tvm/src/tir/ir/expr.cc", line 447
TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int32 vs. int64
  1: tvm::tir::ExprMutator::VisitExpr_(tvm::tir::FloorDivNode const*)
  0: tvm::tir::FloorDiv::FloorDiv(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
  File "/home/masa/projects/dev/tvm/src/tir/ir/expr.cc", line 322
TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int64 vs. int32
  1: tvm::tir::ArgBinder::BindBuffer(tvm::tir::Buffer const&, tvm::tir::Buffer const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, bool)
  0: tvm::tir::ArgBinder::Bind_(tvm::PrimExpr const&, tvm::PrimExpr const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, bool)
  File "/home/masa/projects/dev/tvm/src/tir/transforms/arg_binder.cc", line 52
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------

  Check failed: arg.dtype() == value.dtype() (int32 vs. int64) :
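
The traces above share one pattern: a constructor or binder compares two dtypes and aborts when they differ. A minimal Python sketch of that pattern follows. `Expr`, `less_equal` and `cast` are hypothetical toy names, not TVM's API; the sketch only shows why an int32 operand meeting an int64 operand fails, and how an explicit cast on one side avoids it.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Expr:
    # Hypothetical expression record: printable text plus a dtype string.
    text: str
    dtype: str


def less_equal(a: Expr, b: Expr) -> Expr:
    # Same shape as the failing check in the traces: operands must agree on dtype.
    if a.dtype != b.dtype:
        raise TypeError(f"mismatched types. {a.dtype} vs. {b.dtype}")
    return Expr(f"({a.text} <= {b.text})", "bool")


def cast(e: Expr, dtype: str) -> Expr:
    # Wrap the expression in a cast only when the dtype actually changes.
    return e if e.dtype == dtype else Expr(f"cast({dtype}, {e.text})", dtype)


i32 = Expr("i", "int32")
n64 = Expr("n", "int64")
ok = less_equal(cast(i32, "int64"), n64)  # both sides are int64 now
```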

@tkonolige
Contributor Author

I think it may fix the middle and top issues, but I'm not positive.

Also, we should think about checking that every For also has matching dtypes between its var and extent.
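
A check like that could look roughly as follows. This is a sketch over a hypothetical `For` record with string dtypes, not TVM's TIR classes or visitor API; `find_mismatched_loops` is an illustrative name. It walks a loop nest and reports every loop whose variable dtype differs from its extent dtype.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class For:
    # Hypothetical loop node: name, dtype of the loop var, dtype of the extent.
    name: str
    var_dtype: str
    extent_dtype: str
    body: List["For"] = field(default_factory=list)


def find_mismatched_loops(loop: For) -> List[str]:
    """Return the names of all loops whose var and extent dtypes differ."""
    bad = []
    if loop.var_dtype != loop.extent_dtype:
        bad.append(loop.name)
    for child in loop.body:
        bad.extend(find_mismatched_loops(child))
    return bad


nest = For("i", "int64", "int64", [
    For("j", "int32", "int64", [For("k", "int32", "int32")]),
])
print(find_mismatched_loops(nest))  # prints ['j']
```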

@masahi
Member

masahi commented Mar 11, 2022

Just sent a PR fixing a related problem (one of my errors above):

#10584

Contributor

@mbrookhart mbrookhart left a comment

@junrushao1994 @masahi Any concerns with this?

@masahi masahi merged commit 4cdbf5c into apache:main Mar 12, 2022
lazycal added a commit to lazycal/tvm that referenced this pull request Mar 15, 2022
masahi added a commit to masahi/tvm that referenced this pull request Mar 15, 2022
pfk-beta pushed a commit to pfk-beta/tvm that referenced this pull request Apr 11, 2022
* [TE] Promote substituted variable to iter_var's dtype

This fixes a bug where an iteration variable and its associated loop
variable have a mismatched dtype.

* add check to iter var constructor. fix two bad uses

* problem is more complicated than I thought

* one more fix

* remove old comments