
PyTorch Softmax Ops #846
Merged · 51 commits · Jun 28, 2024

Conversation

@HAKSOAT (Contributor) commented Jun 23, 2024

Description

This PR implements the Softmax, LogSoftmax, and SoftmaxGrad Ops for the PyTorch backend.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@HAKSOAT (Contributor, Author) commented Jun 23, 2024

Hi @ricardoV94,

I went through PR #764 and observed that at some point in that work Softmax, LogSoftmax, and SoftmaxGrad were added, in this commit.

So I'm curious why they were taken out. I see something about CUDA, but was hoping to get more context.

@ricardoV94 (Member) commented

> Hi @ricardoV94,
>
> I went through PR #764 and observed that at some point in that work Softmax, LogSoftmax, and SoftmaxGrad were added, in this commit.
>
> So I'm curious why they were taken out. I see something about CUDA, but was hoping to get more context.

We tried to reduce the scope of that initial PR to get the basics in and merge it sooner rather than later. Dropping the implementation of those Ops was probably just that.

def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis

    if axis is None:
Member

Hmm, None means all axes in PyTensor; doesn't PyTorch support that case?

Contributor Author

None in PyTorch implicitly means axis 1: https://discuss.pytorch.org/t/implicit-dimension-choice-for-softmax-warning/12314/10

And that functionality is deprecated now, giving the warning:

UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
torch.nn.functional.softmax(torch.Tensor(x), dim=None)

@ricardoV94 (Member) commented Jun 24, 2024

So can we write the function to achieve the same behavior as softmax along all axes? I guess in that case we need to do it manually with exp(x) / torch.sum(exp(x), axis=None) (with the max subtraction for stability), or is that also not allowed?

We shouldn't raise an error that is tied to the torch API, but try to achieve the intended behavior of PyTensor. Whether we use torch.softmax or something else is an implementation detail for PyTensor users.
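A minimal sketch of the manual all-axes softmax described here (the helper name is ours, for illustration; it is not part of the PR):

import torch

# Hypothetical helper: numerically stable softmax over all elements,
# mirroring PyTensor's axis=None semantics.
def softmax_all_axes(x: torch.Tensor) -> torch.Tensor:
    # Subtract the global max before exponentiating for numerical stability.
    e_x = torch.exp(x - x.max())
    # Normalize by the sum over every element so the whole tensor sums to 1.
    return e_x / e_x.sum()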

Member


Or something like: if axis is None: torch.softmax(x.ravel()).reshape(x.shape)?
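For reference, a hedged version of that one-liner (torch.softmax requires an explicit dim, so passing dim=0 on the flattened tensor is our addition):

import torch

def softmax_via_ravel(x: torch.Tensor) -> torch.Tensor:
    # Flatten, softmax over the single remaining dimension, then restore shape.
    return torch.softmax(x.ravel(), dim=0).reshape(x.shape)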

Contributor

Yes, reshaping can be used for axis=None. Also, it might be a good idea to make sure x has a torch.float dtype; the softmax would fail if the dtype is int or long.

Contributor Author


@ricardoV94 , what is the expected behaviour when the input is of dtype int?

I'm thinking of converting to float implicitly.

Member

Have to check the integer case, but don't forget to test it as well.

@HAKSOAT (Contributor, Author) commented Jun 25, 2024

Hi @ricardoV94, I have added the LogSoftmax and SoftmaxGrad Ops.

I have a question about testing the types: how do I do this? Every time I send the FunctionGraph and the inputs into compare_pytorch_and_py, the function always ends up with a tensor of floats somehow.

Can you give pointers on another way to test the actual inputs without a conversion happening in the middle? I am asking because I need to check that the Ops still work with an input of dtype int.

@ricardoV94 (Member) commented Jun 26, 2024

In the test you're creating a matrix, which has a default dtype of float32 or float64. You can parametrize the test with an explicit dtype, say pytest.mark.parametrize("dtype", ["float64", "int64"]), to test integer types as well, and create x = matrix("x", dtype=dtype).

This is something we should mention in the docs @HarshvirSandhu
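A minimal sketch of such a parametrized test (the test name, the import path of the compare_pytorch_and_py helper, and its exact signature are assumptions based on this thread, not the PR's actual test):

import numpy as np
import pytest

from pytensor.graph import FunctionGraph
from pytensor.tensor.special import softmax
from pytensor.tensor.type import matrix

# Assumed location of the helper mentioned above.
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.mark.parametrize("dtype", ["float64", "int64"])
def test_softmax_dtype(dtype):
    # Create the symbolic input with an explicit dtype so the integer case
    # is exercised as well.
    x = matrix("x", dtype=dtype)
    out = softmax(x, axis=-1)
    fgraph = FunctionGraph([x], [out])
    test_input = np.arange(6, dtype=dtype).reshape(2, 3)
    compare_pytorch_and_py(fgraph, [test_input])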

@ricardoV94 (Member) commented

CC @HarshvirSandhu for helping with the review as the PR progresses

ricardoV94 added the labels enhancement (New feature or request) and torch (PyTorch backend) on Jun 25, 2024
@@ -9,7 +9,7 @@ channels:
 dependencies:
   - python>=3.10
   - compilers
-  - numpy>=1.17.0
+  - numpy>=1.17.0,<2
Contributor Author


This change is because of this error: #827


def softmax(x):
    if not torch.is_floating_point(x):
        x = x.to(torch.float32)
@ricardoV94 (Member) commented Jun 26, 2024

Should probably convert to the output type advertised by PyTensor. The second argument of the funcify function is node (now hidden inside **kwargs). If you retrieve node, you can check the dtype of the output via node.outputs[0].dtype. That is probably what you should convert to, not necessarily float32.

Suggested change
x = x.to(torch.float32)
x = x.to(torch.float32)
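A sketch of what that suggestion could look like (hypothetical code for illustration; this cast-to-the-advertised-dtype approach was later dropped in favor of raising NotImplementedError, as discussed below):

import torch

def pytorch_funcify_Softmax(op, node, **kwargs):
    axis = op.axis
    # Cast to the dtype PyTensor advertises for the output instead of
    # hard-coding float32 (e.g. "float64" maps to torch.float64).
    out_dtype = getattr(torch, node.outputs[0].dtype)

    def softmax(x):
        if not torch.is_floating_point(x):
            x = x.to(out_dtype)
        # axis=None handling is omitted here; see the reshape idea above.
        return torch.softmax(x, dim=axis)

    return softmax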

@ricardoV94 (Member) commented Jun 26, 2024

Actually, there's a bug in PyTensor's Softmax: it also fails if you try to execute with an integer dtype in the default backend. I'll open an issue for that.

For now it's enough for the torch dispatch function to raise a NotImplementedError if the input dtype (which you can similarly get from node.inputs[0].dtype) is an integer, and not try to handle it inside the returned softmax function.

Same for the other Softmax-related functions.
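A minimal sketch of that guard, assuming the registration pattern shown elsewhere in this diff (not necessarily the exact code that was merged):

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.special import Softmax


@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, node, **kwargs):
    axis = op.axis

    # Refuse integer inputs up front instead of silently casting; PyTensor's
    # own Softmax currently fails on integers too (see #857).
    if "int" in node.inputs[0].dtype:
        raise NotImplementedError(
            "Softmax with integer input dtype is not supported by the PyTorch backend"
        )

    def softmax(x):
        if axis is not None:
            return torch.softmax(x, dim=axis)
        # axis=None means softmax over all elements in PyTensor.
        return torch.softmax(x.ravel(), dim=0).reshape(x.shape)

    return softmax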

Member


Opened an issue: #857

@@ -40,11 +40,14 @@ def dimshuffle(x):
@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis
    dtype = kwargs["node"].outputs[0].dtype
@ricardoV94 (Member) commented Jun 27, 2024

Maybe better to check the input dtype, because we would fail if we pass an integer. PyTensor could start saying Softmax takes integers as input and outputs floats once we fix it. Sorry if I said the output before.

Suggested change
-    dtype = kwargs["node"].outputs[0].dtype
+    dtype = kwargs["node"].inputs[0].dtype

@ricardoV94 (Member) commented

Left one comment, after that I think it's ready!

@HAKSOAT (Contributor, Author) commented Jun 28, 2024

Hi @ricardoV94, I have made that change and resolved the existing merge conflicts.


codecov bot commented Jun 28, 2024

Codecov Report

Attention: Patch coverage is 73.33333% with 8 lines in your changes missing coverage. Please review.

Project coverage is 80.98%. Comparing base (920b409) to head (b4cdce0).
Report is 148 commits behind head on main.

Files with missing lines | Patch % | Lines
pytensor/link/pytorch/dispatch/elemwise.py | 73.33% | 8 Missing ⚠️

@@            Coverage Diff             @@
##             main     #846      +/-   ##
==========================================
- Coverage   80.98%   80.98%   -0.01%     
==========================================
  Files         169      169              
  Lines       46985    47025      +40     
  Branches    11494    11501       +7     
==========================================
+ Hits        38052    38084      +32     
- Misses       6716     6727      +11     
+ Partials     2217     2214       -3     
Files with missing lines | Coverage Δ
pytensor/link/pytorch/dispatch/elemwise.py | 69.81% <73.33%> (+4.59%) ⬆️

... and 3 files with indirect coverage changes

ricardoV94 merged commit 17fa8b1 into pymc-devs:main on Jun 28, 2024
56 of 57 checks passed
@ricardoV94 (Member) commented

Thanks @HAKSOAT
