Skip to content

Commit

Permalink
Temporarily make ShardedDeviceArray.__init__ optionally accept old si…
Browse files Browse the repository at this point in the history
…gnature.

This allows us to incrementally update ShardedDeviceArray creators to the new constructor introduced in 07571ae.
  • Loading branch information
skye committed Apr 15, 2020
1 parent 7a61aea commit eca941b
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,15 @@ def __init__(self,
sharding_spec: ShardingSpec,
device_buffers: List[xb.xla_client._xla.PyLocalBuffer],
indices: Optional[Tuple[Index, ...]] = None):
# TODO(skye): this is temporary staging while we switch users over to
# providing sharding_spec. It assumes that any pre-existing callers are
# creating pmap-style ShardedDeviceArrays.
if device_buffers is None:
assert isinstance(sharding_spec[0], xb.xla_client._xla.PyLocalBuffer)
device_buffers = sharding_spec
sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0],
aval.shape[1:])

# TODO(skye): assert invariants. Keep performance in mind though.
if indices is None:
indices = spec_to_indices(aval.shape, sharding_spec)
Expand Down

0 comments on commit eca941b

Please sign in to comment.