Skip to content

Commit

Permalink
Temporarily make ShardedDeviceArray.__init__ optionally accept old si… (
Browse files Browse the repository at this point in the history
jax-ml#2730)

This allows us to incrementally update ShardedDeviceArray creators to the new constructor introduced in jax-ml@07571ae.
  • Loading branch information
skye authored and jacobjinkelly committed Apr 21, 2020
1 parent f5543b0 commit a621b79
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,18 @@ class ShardedDeviceArray(xla.DeviceArray):
# TODO(skye): expose PyLocalBuffers in xla_client
def __init__(self,
aval: ShapedArray,
sharding_spec: ShardingSpec,
device_buffers: List[xb.xla_client._xla.PyLocalBuffer],
sharding_spec, # TODO(skye): add type annotation back, see below
device_buffers: List[xb.xla_client._xla.PyLocalBuffer] = None,
indices: Optional[Tuple[Index, ...]] = None):
# TODO(skye): this is temporary staging while we switch users over to
# providing sharding_spec. It assumes that any pre-existing callers are
# creating pmap-style ShardedDeviceArrays.
if device_buffers is None:
assert isinstance(sharding_spec[0], xb.xla_client._xla.PyLocalBuffer)
device_buffers = sharding_spec
sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0],
aval.shape[1:])

# TODO(skye): assert invariants. Keep performance in mind though.
if indices is None:
indices = spec_to_indices(aval.shape, sharding_spec)
Expand Down

0 comments on commit a621b79

Please sign in to comment.