
Softmax fails with integer dtypes only at runtime #857

Open
ricardoV94 opened this issue Jun 26, 2024 · 1 comment
Labels
bug Something isn't working Op implementation

Comments

@ricardoV94
Member

ricardoV94 commented Jun 26, 2024

Description

Brought up in #846

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", dtype="int64")
out = pt.special.softmax(x)

# The output type is inferred as int64, which doesn't seem right for softmax
out.dprint(print_type=True)
# Softmax{axis=None} [id A] <Vector(int64, shape=(?,))>
# └─ x [id B] <Vector(int64, shape=(?,))>

# Compilation raises no complaints
fn = pytensor.function([x], out)

fn([1, 2, 3])  # TypeError: not a float
```

We should either raise at graph definition time or cast the input to float. SciPy's `scipy.special.softmax` accepts integer inputs (and returns floats), so we could do the same.

ricardoV94 added the bug (Something isn't working) and Op implementation labels Jun 26, 2024
ricardoV94 mentioned this issue Jun 26, 2024
@ricardoV94
Member Author

This problem will go away if we use OpFromGraph to represent Softmax, since the elementary Operations it would be built from are well defined for integer inputs (e.g. `exp` of an integer returns a float).
