Temporarily make ShardedDeviceArray.__init__ optionally accept old si… #2730

Merged · 2 commits · Apr 16, 2020
13 changes: 11 additions & 2 deletions jax/interpreters/pxla.py
@@ -447,9 +447,18 @@ class ShardedDeviceArray(xla.DeviceArray):
   # TODO(skye): expose PyLocalBuffers in xla_client
   def __init__(self,
                aval: ShapedArray,
-               sharding_spec: ShardingSpec,
-               device_buffers: List[xb.xla_client._xla.PyLocalBuffer],
+               sharding_spec,  # TODO(skye): add type annotation back, see below
+               device_buffers: List[xb.xla_client._xla.PyLocalBuffer] = None,
                indices: Optional[Tuple[Index, ...]] = None):
+    # TODO(skye): this is temporary staging while we switch users over to
+    # providing sharding_spec. It assumes that any pre-existing callers are
+    # creating pmap-style ShardedDeviceArrays.
+    if device_buffers is None:
+      assert isinstance(sharding_spec[0], xb.xla_client._xla.PyLocalBuffer)
+      device_buffers = sharding_spec
+      sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0],
+                                          aval.shape[1:])
+
     # TODO(skye): assert invariants. Keep performance in mind though.
     if indices is None:
       indices = spec_to_indices(aval.shape, sharding_spec)
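The diff above uses a common backward-compatibility pattern: when the new keyword argument is absent, the constructor inspects the type of the second positional argument to decide whether it received the old-style buffer list or the new sharding spec. The sketch below illustrates that pattern in isolation; the class and type names (`FakeBuffer`, `ShardedArray`) and the placeholder spec are simplified stand-ins, not the real JAX types.

```python
class FakeBuffer:
    """Stand-in for xla_client's PyLocalBuffer."""

class ShardedArray:
    def __init__(self, shape, sharding_spec, device_buffers=None):
        if device_buffers is None:
            # Old calling convention: the buffer list was passed where the
            # sharding spec now goes. Detect it by type, shift it into
            # place, and synthesize a pmap-style spec (placeholder here).
            assert isinstance(sharding_spec[0], FakeBuffer)
            device_buffers = sharding_spec
            sharding_spec = ("pmap", shape[0])
        self.shape = shape
        self.sharding_spec = sharding_spec
        self.device_buffers = device_buffers

# Both call styles construct an equivalent object during the transition:
bufs = [FakeBuffer(), FakeBuffer()]
old_style = ShardedArray((2, 3), bufs)               # legacy signature
new_style = ShardedArray((2, 3), ("pmap", 2), bufs)  # new signature
assert old_style.sharding_spec == new_style.sharding_spec
```

Note the trade-off this staging accepts: the type annotation on `sharding_spec` must be dropped temporarily, since the parameter transiently accepts two unrelated types until all callers migrate.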