-
Notifications
You must be signed in to change notification settings - Fork 38
feat: sharding with non-divisible dimensions [alternate approach] #825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3e10c82 to
11a38ed
Compare
This comment was marked as outdated.
This comment was marked as outdated.
This comment was marked as outdated.
This comment was marked as outdated.
Collaborator
Author
|
https://github.com/openxla/xla/blob/a83a2d3f1f977e4825cb210320597c5825c25ead/xla/client/executable_build_options.h#L327-L330 are false by default. so this is a bit unexpected. module @reactant_fn_test attributes {mhlo.frontend_attributes = {xla.sdy.meshes = "{mesh = #sdy.mesh<[\\\22data\\\22=2, \\\22model\\\22=4]>}"}, mhlo.num_partitions = 8 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<7x2xf32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22model\\\22}, {\\\22data\\\22}]>"}, mhlo.sharding = "{devices=[4,2]<=[2,4]T(1,0)}"}) -> tensor<7x1xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = mhlo.reduce(%arg0 init: %0) applies mhlo.add across dimensions = [1] : (tensor<7x2xf32>, tensor<f32>) -> tensor<7xf32>
%2 = mhlo.reshape %1 : (tensor<7xf32>) -> tensor<7x1xf32>
return %2 : tensor<7x1xf32>
}
}MHLO sharding: |
wsmoses
reviewed
Feb 28, 2025
5b0a5e1 to
81f1477
Compare
81f1477 to
e70f4c2
Compare
Collaborator
Author
|
@wsmoses is this good to go from your end? |
wsmoses
approved these changes
Mar 4, 2025
avik-pal
added a commit
that referenced
this pull request
Mar 4, 2025
* feat: support implicit padding from XLA * feat: use XLA for shard-info if we need padding * test: padding for sharding * fix: return type
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Unless I am missing something obvious, XLA shouldn't be modifying input shardings unless we opt in for that:
Case I: Divisible Dimensions
As expected the inputs are of size [2, 1]
Case II: Non-Divisible Dimensions
Expected input dims should be [2, 1] (with 2 replicas requiring padded inputs)