Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perform shape checks for self.param AFTER unboxing #4079

Merged
merged 1 commit into from
Jul 23, 2024

Conversation

danielwatson6
Copy link
Contributor

This pull request will unbox params before performing shape checks in flax/core/scope.py, rather than after.

This should not affect pjit / jax.jit users as unboxing generally doesn't touch traced array data (and wasn't designed for it). However, users that prefer per-device code and manual collectives will run into trouble when implementing certain patterns like FSDP through their own boxes (i.e., subclassing nn.meta.AxisMetadata to achieve a "delayed all-gather") because the shape check fixed here happens at the wrong time.

Fixes # (issue) N/A

Checklist

  • This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other
    checks if that's the case).
  • This change is discussed in a Github issue/
    discussion (please add a
    link).
  • The documentation and docstrings adhere to the
    documentation guidelines.
  • This change includes necessary high-coverage tests.
    (No quality testing = no merge!)

All existing tests pass.

@codecov-commenter
Copy link

Codecov Report

Attention: Patch coverage is 0% with 4 lines in your changes missing coverage. Please review.

Project coverage is 0.00%. Comparing base (31adb00) to head (bc68908).
Report is 87 commits behind head on main.

Files Patch % Lines
flax/core/scope.py 0.00% 4 Missing ⚠️
Additional details and impacted files
@@          Coverage Diff           @@
##            main   #4079    +/-   ##
======================================
  Coverage   0.00%   0.00%            
======================================
  Files        106     107     +1     
  Lines      13582   13820   +238     
======================================
- Misses     13582   13820   +238     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@copybara-service copybara-service bot merged commit d718655 into google:main Jul 23, 2024
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants