-
Notifications
You must be signed in to change notification settings - Fork 38
feat: use parameter shardings from XLA #743
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
forgot to export some names for mac. Will fix in next JLL. |
|
An even simpler sharding case that gives incorrect results # Currently an extremely simple test
using Reactant, Test
const addressable_devices = Reactant.addressable_devices()
mesh = Sharding.Mesh(reshape(collect(Int64, 0:3), (2, 2)), ("data", "model"))
# samples_sharding = Sharding.NamedSharding(mesh, (nothing, "data"))
w1_sharding = Sharding.NamedSharding(mesh, ("model", nothing))
# w2_sharding = Sharding.NamedSharding(mesh, ("data", nothing))
# samples = reshape(collect(Float32, 1:84), 7, 12)
w1 = reshape(collect(Float32, 1:4), 2, 2)
w2 = reshape(collect(Float32, 1:4), 2, 2)
w1_ra = Reactant.to_rarray(w1; sharding=w1_sharding)
w2_ra = Reactant.to_rarray(w2; sharding=w1_sharding)
@code_xla *(w2_ra, w1_ra)
# @jit *(w2_ra, w1_ra) |
|
foo(x) = x .+ x'
x = reshape(collect(Float32, 1:4), 2, 2)
x_ra = Reactant.to_rarray(
x;
sharding=Sharding.NamedSharding(
Sharding.Mesh(reshape(collect(Int64, 0:3), (2, 2)), ("data", "model")),
("data", nothing),
),
)
@code_xla foo(x_ra)
@jit foo(x_ra) |
| tmp = Reactant.ConcreteRArray( | ||
| ones(sharding_and_shape.shape); sharding=LazySharding(sharding_and_shape.sharding) | ||
| ) | ||
| _, exec, _, _, _ = Reactant.Compiler.compile_xla(internal_simple_op, (tmp,)) | ||
| return XLA.CondensedOpSharding(only(XLA.get_parameter_shardings(exec))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not the most ideal solution, but is guaranteed to be correct. After GB I will see if there is a nicer way to do this
|
Locally tests pass. We need a new JLL before CI is green |
strangely enough dot_general is giving me incorrect results when using sharding. Every other operator I tested with the exact same sharding setup gives correct result.
Once the JLL builds I will test it out on a tpu pod to verify this isn't some weird behavior originating from
--xla_force_host_platform_device_count=8This should also unblock PRONTOLab/GB-25#8 (comment)