diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 545de95da0f5..f2fa94a82da5 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -447,9 +447,18 @@ class ShardedDeviceArray(xla.DeviceArray): # TODO(skye): expose PyLocalBuffers in xla_client def __init__(self, aval: ShapedArray, - sharding_spec: ShardingSpec, - device_buffers: List[xb.xla_client._xla.PyLocalBuffer], + sharding_spec, # TODO(skye): add type annotation back, see below + device_buffers: List[xb.xla_client._xla.PyLocalBuffer] = None, indices: Optional[Tuple[Index, ...]] = None): + # TODO(skye): this is temporary staging while we switch users over to + # providing sharding_spec. It assumes that any pre-existing callers are + # creating pmap-style ShardedDeviceArrays. + if device_buffers is None: + assert isinstance(sharding_spec[0], xb.xla_client._xla.PyLocalBuffer) + device_buffers = sharding_spec + sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0], + aval.shape[1:]) + # TODO(skye): assert invariants. Keep performance in mind though. if indices is None: indices = spec_to_indices(aval.shape, sharding_spec)