Skip to content

Commit

Permalink
Fix a bug where full and use_mesh outside jit did not work becaus…
Browse files Browse the repository at this point in the history
…e the `shard` passed to `make_array_from_callback` was sharded on all devices instead of just 1 device.

This is because `convert_element_type` returning an output on all devices of the mesh because of the surrounding `use_mesh` context.

PiperOrigin-RevId: 735909962
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Mar 11, 2025
1 parent 29bfd00 commit f45cbf3
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3016,6 +3016,7 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
isinstance(fill_value, array.ArrayImpl) and sharding._is_concrete):
broadcast_shape = sharding.shard_shape(shape)
shard = broadcast(fill_value, broadcast_shape)
shard = shard.addressable_data(0)
return array.make_array_from_callback(shape, sharding, lambda _: shard)

if sharding is not None and not sharding._is_concrete:
Expand Down

0 comments on commit f45cbf3

Please sign in to comment.