Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

revise xla.device_put device logic #2907

Merged
merged 2 commits into from
May 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from collections import defaultdict
import itertools as it
import operator as op
from typing import Any, Callable, Dict, Sequence, Type
from typing import Any, Callable, Dict, Sequence, Type, Optional

from absl import logging
import numpy as onp
Expand Down Expand Up @@ -920,16 +920,23 @@ def _device_put_device_array(x, device):
return _force(x).device_buffer
device_put_handlers[DeviceArray] = _device_put_device_array

def _copy_device_array_to_device(x, device):
if is_device_constant(x):
def _copy_device_array_to_device(x: DeviceArray, device: Optional[xc.Device]):
if device is None:
# no copying to be done because there's no target specified
return x
elif is_device_constant(x):
# create a new DeviceArray with the same lazy expr, no copying
return DeviceArray(x.aval, device, x._lazy_expr, DeviceConstant(device))
elif xb.get_device_backend(device).platform == x.device_buffer.platform():
if device is None or x.device_buffer.device() == device:
# source and target platforms are the same
if x.device_buffer.device() == device:
# no copying to be done because source equals target
return x
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is not quite right: if x is uncommitted to the device, it should now be committed to the device. This case is covered in the upcoming #2882.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch, thanks!

else:
# move the buffer with a device-to-device copy
moved_buf = x.device_buffer.copy_to_device(device)
else:
# Buffers from different XLA backends are passed through the host.
# buffers from different XLA backends are passed through the host.
backend = xb.get_device_backend(device)
moved_buf = backend.buffer_from_pyval(x.device_buffer.to_py(), device)
return DeviceArray(x.aval, device, x._lazy_expr, moved_buf)
Expand Down
3 changes: 0 additions & 3 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,6 @@ def test_device_put_across_platforms(self):
x = api.device_put(val, device=cpu_device)
self.assertEqual(x.device_buffer.device(), cpu_device)

y = api.device_put(x)
self.assertEqual(y.device_buffer.device(), default_device)

def test_jit_on_all_devices(self):
# Verifies we can run the same computation on every device present, even
# if they are, for example, different models of GPU.
Expand Down
18 changes: 18 additions & 0 deletions tests/multibackend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,24 @@ def my_sin(x): return np.sin(x)
result4 = api.jit(my_sin, backend="cpu")(2)
self.assertEqual(result4.device_buffer.device(), cpus[0])

@jtu.skip_on_devices("cpu")  # test only makes sense on non-cpu backends
def test_indexing(self):
  """Check that indexing an array put on a CPU device stays on that device."""
  # Regression test for https://github.com/google/jax/issues/2905
  cpu_device = api.devices("cpu")[0]

  source = api.device_put(onp.ones(2), cpu_device)
  element = source[0]
  # The indexed result's buffer must report the same device as its input.
  self.assertEqual(element.device_buffer.device(), cpu_device)

@jtu.skip_on_devices("cpu")  # test only makes sense on non-cpu backends
def test_sum(self):
  """Check that summing an array put on a CPU device stays on that device."""
  # Regression test for https://github.com/google/jax/issues/2905
  cpu_device = api.devices("cpu")[0]

  source = api.device_put(onp.ones(2), cpu_device)
  total = source.sum()
  # The reduction's buffer must report the same device as its input.
  self.assertEqual(total.device_buffer.device(), cpu_device)


if __name__ == "__main__":
absltest.main()