Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid casting all terms to the same dtype in logp #7329

Closed
wants to merge 1 commit into from

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented May 22, 2024

Description

This change avoids many warnings when compiling a logp to JAX in pytensor.config.floatX = "float32". Since the terms are all scalars, it should also be more efficient, because it avoids an explict MakeVector.

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7329.org.readthedocs.build/en/7329/

@ricardoV94 ricardoV94 force-pushed the add_instead_of_sum branch from 70196cc to 50697be Compare May 22, 2024 16:23
Copy link
Contributor

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this change affect int casting? Aren’t the logp terms already floatX?

@ricardoV94
Copy link
Member Author

Why does this change affect int casting? Aren’t the logp terms already floatX?

This change avoids explicit castings to float64. Some logps may be float32 and others float64, and if you sum there will be an explicit casting + make vector. If you do add there is only implicit casting and jax doesn't emit a warning about ignoring explicit float64 casts.

This has nothing to do with ints anymore, I have split those changes into another PR

Copy link
Contributor

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. The errors seem like a pytensor problem with ScalarType.upcast, so this PR should not be merged until pytensor sorts that out.

@ricardoV94
Copy link
Member Author

Not a PyTensor error, but a cryptic failure when add has no entries

@ricardoV94
Copy link
Member Author

Actually we have a rewrite that does this, but it's a bit conservative with dtypes... Maybe we can tweak this rewrite instead.

https://github.com/pymc-devs/pytensor/blob/f7b0a7a48b929605a083e13a12f144040a7fe265/pytensor/tensor/rewriting/basic.py#L919-L956

@ricardoV94 ricardoV94 closed this May 23, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants