-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes #10172
Conversation
Not sure whom I should request for reviews, but it seems simimlar to this PR #9582. So ccing the reviewers there @YuchenJin @junrushao1994 @Mousius |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM and thanks for the contribution! But I am wondering if we can simply perform such compatible casting when constructing Ramp node? (https://github.com/apache/tvm/tree/main/src/tir/ir/expr.cc#L705)
We should add ICHECK(base.is_int() && stride.is_int())
if the two will only be integers.
if (base.dtype().is_int()) { | ||
ICHECK(stride.dtype().is_int()) << "Ramp base is int but stride is " << stride.dtype(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can simply assume that base and stride should be of integer types. However, I also noticed that in
Line 705 in 22c488e
Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) { |
Such assumptions are not checked. I am a bit curious if there will be, say base/stride
in float types?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know if they can be floats. So I added that conservative line of code that only acts on integers. If they can only be integer we should add that ICHECK
you mention.
@lazycal Using this impl (based on yours) in https://github.com/lazycal/tvm/blob/ffe6649855c4c247f4bb85c9d48c5ca157850a1d/src/tir/ir/expr.cc#L705 fixes the bug you mentioned and might be more general to overcome other hidden ones if we can assume that Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) {
ICHECK(base.defined());
ICHECK(stride.defined());
ICHECK(base.dtype().is_scalar());
ICHECK(stride.dtype().is_scalar());
ICHECK_GT(lanes, 1);
ICHECK(base.dtype().is_int());
ICHECK(stride.dtype().is_int());
if (base.dtype() != stride.dtype()) {
size_t bits = std::max(base.dtype().bits(), stride.dtype().bits());
DataType dtype = base.dtype().with_bits(bits);
if (base.dtype() != dtype) base = cast(dtype, base);
if (stride.dtype() != dtype) stride = cast(dtype, stride);
}
ObjectPtr<RampNode> node = make_object<RampNode>();
node->dtype = base.dtype().with_lanes(lanes);
node->base = base;
node->stride = stride;
node->lanes = lanes;
node->span = std::move(span);
data_ = std::move(node);
} |
I didn't do what you said because
|
@lazycal Fair consideration! I also tried |
I've just hit a similar error, when compiling an int8 model with tensorized ops (VNNI):
I wonder if this is related. |
…passes (apache#10172) [TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes
Thanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers by @ them in the pull request thread.
The following model
triggers two issues regarding
base
andstride
dtype mismatch inRamp
, one in VectorizeLoop Pass and the other in NarrowDataType Pass. Error message looks likeCheck failed: stride.dtype() == base.dtype() (int32 vs. int64) :
.The fix
int32
. This PR changes it to use the loop variable's dtype.stride
is inferred withint32
, butbase
is not (see the added test case for detail). This PR adds an upcasting when rewriting aRamp
node that hasbase
andstride
inferred with different number of bits.