PyTorch Softmax Ops #846
Conversation
Hi @ricardoV94, I went through this PR: #764 and I observed that at some point in that work, Softmax, LogSoftmax and SoftmaxGrad were added in this commit. So I'm curious why they were taken out. I see something about cuda, but was hoping to get more context.
We tried to reduce the scope of that initial PR to get the basics in and merge it sooner rather than later. Dropping the implementation of those Ops was probably just that.
def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis

    if axis is None:
Hmm, None means all axes in PyTensor, doesn't PyTorch support that case?
None in PyTorch implicitly means axis 1: https://discuss.pytorch.org/t/implicit-dimension-choice-for-softmax-warning/12314/10
And that functionality is deprecated now, giving the warning:
UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
torch.nn.functional.softmax(torch.Tensor(x), dim=None)
So can we write the function to achieve the same behavior as softmax along all axes? I guess in that case we need to do it manually with exp(x) / torch.sum(exp(x), axis=None) (with the max subtraction for stability), or is that also not allowed?
We shouldn't raise an error that is related to the torch API, but try to achieve the intended behavior of PyTensor. Whether we use torch.softmax or something else is an implementation detail for PyTensor users.
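For reference, a minimal sketch of what that manual all-axes softmax could look like in torch (the helper name is hypothetical), with the max subtraction for stability:

import torch

def softmax_all_axes(x):
    # Softmax over every element (PyTensor's axis=None semantics),
    # subtracting the global max first for numerical stability.
    e = torch.exp(x - x.max())
    return e / e.sum()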
Or something like: if axis is None: torch.softmax(x.ravel()).reshape(x.shape)?
Yes, reshaping can be used for axis=None.
Also, it might be a good idea to make sure x is a torch.float dtype; the softmax would fail if the dtype is int or long.
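A sketch of that reshaping idea, with an explicit dim passed to torch.softmax (the function name and the float32 fallback are only illustrative):

import torch

def softmax_any_axis(x, axis=None):
    if not torch.is_floating_point(x):
        x = x.to(torch.float32)  # torch softmax fails on int/long inputs
    if axis is None:
        # Flatten, softmax over the single remaining dimension,
        # then restore the original shape.
        return torch.softmax(x.ravel(), dim=0).reshape(x.shape)
    return torch.softmax(x, dim=axis)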
@ricardoV94, what is the expected behaviour when the input is of dtype int?
I'm thinking of converting to float implicitly.
Have to check the integer case, but don't forget to test it as well.
Hi @ricardoV94, I have added the LogSoftmax and SoftmaxGrad Ops.
I have a question about testing the types: how do I do this? Every time I send the FunctionGraph and the inputs into compare_pytorch_and_py, the function always ends up with a Tensor of floats somehow.
So can you help with pointers on some other way I can test the actual inputs without a conversion happening in the middle? I am asking because of the need to check that the Ops still work with an input of dtype int.
In the test you're creating a matrix, which has a default dtype of float32 or float64. You can parametrize the test with an explicit dtype, say pytest.mark.parametrize("dtype", ["float64", "int64"]), to test integer types as well, and create x = matrix("x", dtype=dtype); a rough sketch of such a test follows below.
This is something we should mention in the docs @HarshvirSandhu
CC @HarshvirSandhu for helping with the review as the PR progresses.
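A rough sketch of such a parametrized test, assuming the compare_pytorch_and_py helper from the existing PyTorch backend tests (its exact signature may differ) and illustrative test name and inputs:

import numpy as np
import pytest

from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.special import softmax
from pytensor.tensor.type import matrix

from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.mark.parametrize("dtype", ["float64", "int64"])
def test_softmax_dtype(dtype):
    x = matrix("x", dtype=dtype)
    out = softmax(x, axis=None)
    fgraph = FunctionGraph([x], [out])
    test_input = np.arange(6, dtype=dtype).reshape(2, 3)
    # Compile with the PyTorch backend and compare against the default backend
    compare_pytorch_and_py(fgraph, [test_input])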
@@ -9,7 +9,7 @@ channels:
 dependencies:
 - python>=3.10
 - compilers
-- numpy>=1.17.0
+- numpy>=1.17.0,<2
This change is because of this error: #827
def softmax(x):
    if not torch.is_floating_point(x):
        x = x.to(torch.float32)
Should probably convert to the output type advertised by PyTensor. The second argument of the funcify function is node (now hidden inside **kwargs). If you retrieve node, you can check the dtype of the output via node.outputs[0].dtype. This is probably what you should convert to, not necessarily float32.
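A small sketch of that suggestion (note that a later comment in this thread revises the advice to check the input dtype instead). The getattr lookup assumes the PyTensor dtype string matches a torch attribute name, and op.axis is assumed to be an explicit integer here:

import torch


def pytorch_funcify_Softmax(op, **kwargs):
    # node is passed by the dispatcher but currently hidden inside **kwargs
    node = kwargs["node"]
    # dtype string advertised by PyTensor for the output, e.g. "float64"
    out_dtype = getattr(torch, node.outputs[0].dtype)

    def softmax(x):
        # Cast to the advertised output dtype before calling torch.softmax
        return torch.softmax(x.to(out_dtype), dim=op.axis)

    return softmax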
Actually there's a bug in PyTensor Softmax: it also fails if you try to execute with an integer dtype in the default backend. I'll open an issue for that.
For now it's enough for the torch dispatch function to raise a NotImplementedError if the input dtype (which similarly you can get from node.inputs[0].dtype) is an integer, and not try to handle it inside the returned softmax function.
Same for the other Softmax-related functions.
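A sketch of that approach, raising at dispatch time for integer inputs (the exact dtype check and error message are illustrative):

import torch


def pytorch_funcify_Softmax(op, node, **kwargs):
    axis = op.axis

    # Refuse integer inputs up front rather than converting them inside
    # the returned function; the default backend currently fails on them too.
    if "int" in node.inputs[0].dtype:
        raise NotImplementedError(
            "Pytorch Softmax is not implemented for integer inputs"
        )

    def softmax(x):
        if axis is None:
            # axis=None in PyTensor means softmax over all elements
            return torch.softmax(x.ravel(), dim=0).reshape(x.shape)
        return torch.softmax(x, dim=axis)

    return softmax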
Opened an issue: #857
@@ -40,11 +40,14 @@ def dimshuffle(x):
@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis
    dtype = kwargs["node"].outputs[0].dtype
Maybe better to check the input dtype, because we would fail if we pass an integer. PyTensor could start saying that Softmax takes integer inputs and outputs floats once we fix it? Sorry if I said the output before.
dtype = kwargs["node"].outputs[0].dtype | |
dtype = kwargs["node"].inputs[0].dtype |
Left one comment, after that I think it's ready!
Hi @ricardoV94, I have made that change and resolved existing merge conflicts.
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #846      +/-   ##
==========================================
- Coverage   80.98%   80.98%   -0.01%
==========================================
  Files         169      169
  Lines       46985    47025      +40
  Branches    11494    11501       +7
==========================================
+ Hits        38052    38084      +32
- Misses       6716     6727      +11
+ Partials     2217     2214       -3
Thanks @HAKSOAT
Description
This PR implements the Softmax, LogSoftmax and SoftmaxGrad Ops for PyTorch.
Related Issue
Checklist
Type of change