Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dynamo] minor enhancements to attention and register a few functions #345

Merged
merged 2 commits into from
Aug 14, 2023

Conversation

xinli-git
Copy link
Collaborator

Encountered a few minor issues when compiling a transformer-based model using torch.compile with very large batch sizes, submitting the fix here.

@yaoyaoding yaoyaoding requested a review from hjjq August 9, 2023 20:32
Copy link
Member

@hjjq hjjq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @xinli-git, LGTM. Just a minor question about torch_sum

Comment on lines +984 to +1005
@register_function(torch.sum)
@register_method(torch.Tensor.sum)
def torch_sum(x: Tensor, *, dtype: Optional[DataType] = None) -> Tensor:
if dtype:
x = x.astype(dtype_from_torch(dtype))
output = ops.sum(x, dims=list(range(len(x.shape))), keep_dim=True)
return output


@register_function(torch.sum)
@register_method(torch.Tensor.sum)
def torch_sum(
x: Tensor, dim, keepdim=False, *, dtype: Optional[DataType] = None, out: Optional[Tensor] = None
) -> Tensor:
if out is not None:
raise NotImplementedError("hidet: does not support torch.sum(..., out=...)")
if dtype:
x = x.astype(dtype_from_torch(dtype))
output = ops.sum(x, dims=dim, keep_dim=keepdim)
return output


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need two torch_sums here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I followed the convention for the mean method above. Not entirely sure either as I thought python does not support overloading. Perhaps @yaoyaoding knows?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python itself does not support function overloading. We used the inspect module to support overloading in hidet. This is needed because some pytorch function/methods have multiple signatures.
The implementation can be found at here and here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

Copy link
Member

@hjjq hjjq Aug 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. But does torch.Tensor.sum and torch.sum have the same signature? If they do, then no need for overloading? https://pytorch.org/docs/stable/generated/torch.sum.html#torch.sum
Also it doesn't seem that either of them has the out argument.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also it doesn't seem that either of them has the out argument.
right, let me fix this

I think the overload is for sum(x, *, dtype) and sum(x, dims, keepdim, ...)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I usually jump to the signatures in the python code to check the variants of the torch functions:
image
and its interesting that the code has out parameter but the documentation does not have.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the overload is for sum(x, *, dtype) and sum(x, dims, keepdim, ...)?

I see.
Also, keepdim in the first case (L989) should default to False?
Lastly, torch.Tensor.sum seems to have a slightly different signature, where dim has a default value (whereas torch.sum doesn't have a default, making dim mandatory). So in the case below, it will resolve to the first case because of missing dim, and possibly produce wrong results?

a = torch.randn(...)
b = a.sum(keepdim=True)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I usually jump to the signatures in the python code to check the variants of the torch functions: image and its interesting that the code has out parameter but the documentation does not have.

image
Interestingly, my pytorch code doesn't have out. Maybe we have different version/build of torch.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is weird, I see two places where Torch generates

@overload
def xxx(args, )

One under ./_C/__init__.pyi (does not have "out"), another under ./_C/_VariableFunctions.pyi (has "out"). However, both are just generated signatures that don't really represent the actual implementation. I think they are there to make the IDEs work.

The actual implementation should be at aten/src/ATen/native/native_functions.yaml, which has "out".

Even if the actual op does not support "out", having an optional out argument should not break the inspect.Signature.bind function. so we should still be fine, and it would be better to include "out" here

@xinli-git xinli-git merged commit edb6503 into hidet-org:main Aug 14, 2023
@xinli-git xinli-git deleted the minor_enhancements branch August 14, 2023 21:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants