[Frontend] Dynamic shape fx trace #294
Conversation
Essentially there was a bug with the norm op. It tries to achieve polymorphism (f32 vs f16) through class overloading (the fp16 task subclasses the fp32 task), but this produces incorrect behaviour when combined with the automatic mixed precision pass: the op was originally created in fp32 and gets re-forwarded by the pass, yet the implement_cuda schedule template still assumes the input is fp32. The result is an array of fp16 inputs reinterpreted as fp32; pointers in C++ cast silently. I think there are two ways to achieve type polymorphism in schedule templates right now.
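To make the symptom concrete, here is a small standalone illustration (using numpy rather than hidet internals) of what reinterpreting an fp16 buffer as fp32 looks like:

import numpy as np

# A float16 buffer of ones, viewed through a float32 "pointer": half as many
# elements come out, with garbage values, which is the symptom described above.
half = np.ones(4, dtype=np.float16)
print(half.view(np.float32))  # two small nonsense values instead of [1, 1, 1, 1]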
Hi @Aalanli, the second method is our current design. We have some base operators (matmul, conv2d) that support arbitrary data types and use the auto-scheduler to schedule. These base operators will be resolved to specialized ones with specialized templates, and we should check the special condition in the task definition (like in this case, we should assert that the input dtype is fp16).
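A minimal sketch of that convention; the function name below is hypothetical, and only the dtype guard mirrors what is described (it uses the same kind of check that appears in the resolve rule in this diff):

from hidet.graph import Tensor
from hidet.ir import dtypes

def make_fp16_norm_task(x: Tensor, dims):
    # hypothetical specialized-task constructor: the fp16 template only supports
    # float16, so fail loudly here instead of letting the buffer be reinterpreted
    assert x.dtype == dtypes.float16, 'fp16 normalize task expects a float16 input'
    ...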
Thanks @Aalanli. Overall looks good to me!
@xinli-git could you also have a look at this PR (especially about the normalization part).
# unfortunately, when dynamic=True in torch.compile, there may exist other non-tensor parameters
# in example inputs
For those dynamic shapes, I am wondering whether these scalar parameters act as the shapes of the input tensors. If that's the case, we can ignore those scalar parameters.
Say a torch model gives us
sample_inputs = [tensor(['m', 'n']), 'm', 'n']
We can declare the symbolic variables 'm' and 'n' when we define the symbolic tensor and ignore the 'm' and 'n' scalar parameters.
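A sketch of that idea, assuming hidet.symbol accepts named symbolic dimensions:

import hidet

# declare 'm' and 'n' as symbolic dimensions on the tensor itself; the trailing
# scalar 'm' and 'n' arguments produced by the dynamic trace can then be ignored
x = hidet.symbol(['m', 'n'], dtype='float32', device='cuda')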
Any clue on this?
@register_function(operator.iadd)
def iadd(x: Tensor, y: Tensor):
-    return ops.add(x, y)
+    return x + y
So the x and y could be DynInt?
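My reading of why the change to x + y matters here, shown with a trivial standalone example (not hidet code): with dynamic=True the traced graph can route plain or symbolic ints through this registered function, and plain Python addition still handles them.

def iadd(x, y):
    # same body as the registered function above, minus the Tensor type hints
    return x + y

print(iadd(3, 4))  # plain ints, e.g. sizes captured by dynamic tracing -> 7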
To be more specific, hidet tasks and their schedule templates should make sure that the schedule template strictly implements what the computation defines. We can take both ways you mentioned. For example, our …
Thanks for the changes in normalize. In principle, this is the right approach. I left two implementations initially so I could add vector load for the fp16 case in the future, but now that there is the vector data type that Yaoyao has recently introduced, keeping op and op_fp16 in a single place is the right way to go, and I intend to do the same for the reduce op.
check_module(model, [x], atol=1e-2, rtol=1e-2)
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True).cuda().eval()
x = torch.randn(*shape).cuda()
check_module(model, [x], atol=1e-2, rtol=1e-2, dynamic=dynamic)
Have we been using the CPU path before this change?
x: Tensor = op.inputs[0]
if not is_contiguous_norm(dims, len(x.shape)):
    return None
if x.dtype != dtypes.float16 or prod([x.shape[dd] for dd in dims]) % 2 != 0:
removing this is safe for now, but we might need to think about how to handle it when we decide to use 2xfp16 types and the norm size is odd.
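A hypothetical sketch of that future check (not code from this PR): with a packed 2xfp16 type, the reduction size must be even for pure vector loads, so odd sizes would need a scalar tail or a fallback to the generic schedule.

from math import prod

def can_use_float16x2(shape, dims) -> bool:
    # product of the reduced dimensions; odd sizes cannot be covered by 2-wide
    # fp16 vector loads alone, so the resolve rule would fall back or add an epilogue
    norm_size = prod(shape[d] for d in dims)
    return norm_size % 2 == 0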
@@ -32,15 +29,6 @@ class NormalizeResolveRule(ResolveRule):
2) resolve_generic: Default case, return the output of the regular f32 reduce schedule.
remove the resolve_fp16 comment above
@yaoyaoding the part about normalize is fine as long as the current CI can pass. Thanks for the notification :)
Reading this again, I think the problem is the reforward. Is my understanding correct that we should not sub-class operators? We should either write them as separate classes or have a generic Operator / Task that works for all input data types? Basically, this line is causing the problem: https://github.com/hidet-org/hidet/blob/main/python/hidet/graph/operator.py#L166 ?
That line is not the problem. The reforward will create the task again based on the new inputs and parameters. The problem is that the task did not check the data type. If the task only supports one data type, it should explicitly assert that its input has that data type. If it accepts the inputs, then its implement function SHOULD support them. We can sub-class operators like ElementwiseBinaryOp, UnaryElementwiseOp, etc.
The key convention here is: keep the task computation definition and the implement function consistent.
I am still not sure what the extra scalar parameters are; let's figure them out before merging this PR.
Thanks @Aalanli!
…. ) (#294)
[Ir][Primitives] add vectorized conversion instructions
[Ir][CuTe] add reduce primitives in cute (#295)
[Ir][CuTe] add mma primitives (#296)
[Ir][CuTe] add other primitives in cute (#297)
[Transforms][CuTe] add instruction selection pass (#298)
[Transforms][CuTe] add resolve bank conflict pass (#299)
[Transforms][CuTe] add resolve auto keywords pass (#300)
[Transforms][CuTe] add shared memory allocation pass (#301)
[Transforms][CuTe] add vectorize elementwise operation pass (#302)
[Transforms][CuTe] add analysis pass (#303)
[Transforms][CuTe] add canonicalization pass (#304)
[Transforms][CuTe] add deadcode elimination pass (#305)
[Transforms][CuTe] refactor cute lowering pass (#306)
[Graph][Ops] matmul cute (#307)
[Ir] cute miscs (#308)
[Tests] cute tests (#309)
[Chore] fix ci (#313)
---------
Co-authored-by: xiaocenxiaocen <xiao.zhang@centml.ai>
enable the option torch.compile(..., dynamic=True)
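A usage sketch of what this enables; it assumes the 'hidet' dynamo backend is available after importing hidet:

import torch
import hidet  # hidet's torch.compile backend (assumed to be registered on import)

model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True).cuda().eval()
compiled = torch.compile(model, backend='hidet', dynamic=True)

with torch.no_grad():
    # with dynamic=True, different batch sizes should reuse the traced graph
    y1 = compiled(torch.randn(1, 3, 224, 224).cuda())
    y2 = compiled(torch.randn(8, 3, 224, 224).cuda())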