[Dynamo] minor enhancements to attention and register a few functions #345
Conversation
Thanks @xinli-git, LGTM. Just a minor question about `torch_sum`.
```python
@register_function(torch.sum)
@register_method(torch.Tensor.sum)
def torch_sum(x: Tensor, *, dtype: Optional[DataType] = None) -> Tensor:
    if dtype:
        x = x.astype(dtype_from_torch(dtype))
    output = ops.sum(x, dims=list(range(len(x.shape))), keep_dim=True)
    return output


@register_function(torch.sum)
@register_method(torch.Tensor.sum)
def torch_sum(
    x: Tensor, dim, keepdim=False, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None
) -> Tensor:
    if out is not None:
        raise NotImplementedError("hidet: does not support torch.sum(..., out=...)")
    if dtype:
        x = x.astype(dtype_from_torch(dtype))
    output = ops.sum(x, dims=dim, keep_dim=keepdim)
    return output
```
Why do we need two `torch_sum`s here?
I followed the convention of the `mean` method above. I'm not entirely sure either, since I thought Python does not support overloading. Perhaps @yaoyaoding knows?
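(For context: hidet's frontend registers multiple interpreters for the same torch function, and the later mention of `inspect.Signature.bind` in this thread suggests that calls are matched against each registered signature in turn. A minimal sketch of that kind of overload resolution, with the hypothetical helper name `resolve_overload`, could look like this.)

```python
import inspect
from typing import Callable, Sequence


def resolve_overload(overloads: Sequence[Callable], *args, **kwargs) -> Callable:
    """Hypothetical helper: return the first registered function whose
    signature can bind the given call arguments."""
    for fn in overloads:
        try:
            inspect.signature(fn).bind(*args, **kwargs)
            return fn
        except TypeError:
            continue
    raise TypeError("no registered overload matches the call arguments")


# With the two torch_sum registrations above, a call like torch.sum(x, dtype=...)
# would bind to the first definition, and torch.sum(x, dim, keepdim=...) to the second.
```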
nice!
I see. But do `torch.Tensor.sum` and `torch.sum` have the same signature? If they do, then there is no need for overloading? https://pytorch.org/docs/stable/generated/torch.sum.html#torch.sum

Also, it doesn't seem that either of them has the `out` argument.
> Also it doesn't seem that either of them has the out argument.
Right, let me fix this.

I think the overload is for `sum(x, *, dtype)` and `sum(x, dims, keepdim, ...)`?
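(For reference, these match the two overloads documented for `torch.sum`; a quick eager-mode check of both call patterns:)

```python
import torch

a = torch.randn(2, 3)

# Overload 1: reduce over all dimensions, optionally casting first.
s1 = torch.sum(a, dtype=torch.float64)
print(s1.shape, s1.dtype)   # torch.Size([]) torch.float64

# Overload 2: reduce over the given dimensions, optionally keeping them.
s2 = torch.sum(a, dim=1, keepdim=True)
print(s2.shape)             # torch.Size([2, 1])
```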
> I think the overload is for sum(x, *, dtype) and sum(x, dims, keepdim, ...)?
I see.

Also, `keepdim` in the first case (L989) should default to False?

Lastly, `torch.Tensor.sum` seems to have a slightly different signature, where `dim` has a default value (whereas `torch.sum` doesn't have a default, making `dim` mandatory). So in the case below, it will resolve to the first overload because `dim` is missing, and possibly produce wrong results?

```python
a = torch.randn(...)
b = a.sum(keepdim=True)
```
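(To make the concern concrete, as an illustration rather than part of the PR: the first registered `torch_sum` above reduces over every dimension with `keep_dim=True` hard-coded, so a call that falls through to it would ignore the requested `keepdim` and change the output shape. Eager PyTorch behaves as follows.)

```python
import torch

a = torch.randn(2, 3)

# Eager PyTorch: summing without a dim gives a 0-d scalar by default.
print(a.sum().shape)                             # torch.Size([])

# Keeping all reduced dimensions instead yields a [1, 1] tensor,
# which is what the first registered overload always produces.
print(a.sum(dim=(0, 1), keepdim=True).shape)     # torch.Size([1, 1])
```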
It is weird, I see two places where Torch generates

```python
@overload
def xxx(args, )
```

one under `./_C/__init__.pyi` (does not have "out"), another under `./_C/_VariableFunctions.pyi` (has "out"). However, both are just generated signatures that don't really represent the actual implementation; I think they are there to make the IDEs work. The actual implementation should be at aten/src/ATen/native/native_functions.yaml, which has "out".

Even if the actual op does not support "out", having an optional `out` argument should not break the `inspect.Signature.bind` function, so we should still be fine, and it would be better to include "out" here.
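(A minimal check of that claim, outside the PR: binding a call against a signature that declares an optional keyword-only `out` succeeds as long as the caller never passes it. The function below is a stand-in with the same parameter list as the second registration.)

```python
import inspect


def torch_sum_stub(x, dim, keepdim=False, *, dtype=None, out=None):
    """Stand-in with the same parameters as the second torch_sum registration."""


sig = inspect.signature(torch_sum_stub)

# The caller never mentions `out`, yet binding still succeeds.
bound = sig.bind("x_placeholder", 1, keepdim=True)
print(bound.arguments)   # {'x': 'x_placeholder', 'dim': 1, 'keepdim': True}
```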
Encountered a few minor issues when compiling a transformer-based model using `torch.compile` with very large batch sizes; submitting the fix here.