Skip to content

Commit

Permalink
Fixed a few more places where device commitment was lost.
Browse files Browse the repository at this point in the history
* trivial jit computations were forcing commitment to the default device
* a device_put with a device specification would not set the commitment
  if the data was already (uncommitted) on the specified device.
* added tests for the above
* once the above were fixed the LaztTest.test_zeros_ones_compilation
  stated to fail because the `sticky` parameter to lazy_force_computation
  was changing. Fixed this by removing stickyness from the compilation key.
* Expanded docstring for jax.device_put; expanded the
  device placement FAQ entry.
  • Loading branch information
gnecula committed May 1, 2020
1 parent ac023bf commit 687da66
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 24 deletions.
5 changes: 5 additions & 0 deletions docs/faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ committed device, and the result will be committed on the
same device. It is an error to invoke an operation on
arguments that are committed to more than one device.

You can also use :func:`jax.device_put` without a ``device`` parameter,
in which case the data is left as is if already on a device (whether
committed or not), or a Python value that is not on any device is
placed uncommitted on the default device.

Jitted functions behave as any other primitive operation
(will follow the data and will error if invoked on data
committed on more than one device).
Expand Down
14 changes: 10 additions & 4 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1524,16 +1524,22 @@ def jaxpr_maker(*args, **kwargs):
return jaxpr_maker


def device_put(x, device=None):
def device_put(x, device: Optional[xc.Device] = None):
"""Transfers ``x`` to ``device``.
Args:
``x``: An array, scalar, or (nested) standard Python container thereof.
``device``: The ``Device`` to transfer ``x`` to.
``device``: The (optional) ``Device`` to transfer ``x`` to.
If given, then the result is committed to the device.
If the ``device`` parameter is ``None``, then this operation behaves like the
identity function if the operand is on any device already, otherwise it
transfers the data to the default device, uncommitted.
For more details on data placement see the https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices.
Returns:
A copy of ``x`` that resides on ``device``. If ``x`` is already on
``device``, returns ``x``.
A copy of ``x`` that resides on ``device``.
"""
return tree_map(lambda y: xla.device_put_p.bind(y, device=device), x)

Expand Down
28 changes: 14 additions & 14 deletions jax/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
# Types
Backend = Any # xc.LocalBackend (why does mypy not like this?)
Device = Any # xc.Device
PyLocalBuffer = Any

FLAGS = flags.FLAGS
flags.DEFINE_bool('jax_debug_nans',
Expand Down Expand Up @@ -494,7 +495,6 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, *arg_specs):
# which are often produced from partial evaluation, don't need compilation,
# and don't need to force their (potentially lazy) arguments.
if not jaxpr.eqns:
device = device or xb.get_backend(None).get_default_device_assignment(1)[0]
return partial(_execute_trivial, jaxpr, device, consts, result_handlers)

log_priority = logging.WARNING if FLAGS.jax_log_compiles else logging.DEBUG
Expand Down Expand Up @@ -587,7 +587,7 @@ def _execute_replicated(compiled, handlers, *args):
if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
return [handler(out_buf) for handler, out_buf in zip(handlers, out_bufs)]

def _execute_trivial(jaxpr, device, consts, handlers, *args):
def _execute_trivial(jaxpr, device: Optional[Device], consts, handlers, *args):
env = {core.unitvar: core.unit}
_map(env.setdefault, jaxpr.invars, args)
_map(env.setdefault, jaxpr.constvars, consts)
Expand Down Expand Up @@ -755,7 +755,7 @@ class DeviceArray(DeviceValue):
__array_priority__ = 100

def __init__(self, aval: core.ShapedArray, device: Optional[Device],
lazy_expr, device_buffer):
lazy_expr: lazy.LazyExpr, device_buffer: PyLocalBuffer):
self.aval = aval
self.device_buffer = device_buffer
self._device = device
Expand Down Expand Up @@ -938,7 +938,10 @@ def _copy_device_array_to_device(x: DeviceArray, device: Optional[xc.Device]) ->
# source and target platforms are the same
if x.device_buffer.device() == device:
# no copying to be done because source equals target
return x
if x._device == device:
return x
else:
moved_buf = x.device_buffer # We need to change stickyness
else:
# move the buffer with a device-to-device copy
moved_buf = x.device_buffer.copy_to_device(device)
Expand All @@ -955,17 +958,16 @@ def _force(x: DeviceArray) -> DeviceArray:
# force x on the device where it lives, but preserve stickiness on result
if x._device:
device = x._device
sticky = True
else:
device = x.device_buffer.device()
sticky = False
force_fun = _lazy_force_computation(sticky, x.aval, device, x._lazy_expr)
return force_fun(x)
force_fun = _lazy_force_computation(x.aval, device, x._lazy_expr)
result = force_fun(x)
return DeviceArray(x.aval, x._device, lazy.array(x.aval.shape), result)

@cache()
def _lazy_force_computation(sticky: bool, aval: core.ShapedArray,
def _lazy_force_computation(aval: core.ShapedArray,
device: Device, lexpr: lazy.LazyExpr
) -> Callable[[DeviceArray], DeviceArray]:
) -> Callable[[DeviceArray], PyLocalBuffer]:
c = xb.make_computation_builder("lazy_force")
if lazy.is_constant(lexpr):
param = None
Expand All @@ -986,15 +988,13 @@ def _lazy_force_computation(sticky: bool, aval: core.ShapedArray,
backend = xb.get_device_backend(device)
compiled = backend.compile(built_c, compile_options=options)

result_device = device if sticky else None
handler = partial(DeviceArray, aval, result_device, lazy.array(aval.shape))
force_fun: Callable[[DeviceValue], DeviceArray]
if lazy.is_constant(lexpr):
def force_fun(_):
return handler(compiled.Execute([])[0])
return compiled.Execute([])[0]
else:
def force_fun(x):
return handler(compiled.Execute([x.device_buffer])[0])
return compiled.Execute([x.device_buffer])[0]
return force_fun


Expand Down
2 changes: 1 addition & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2043,7 +2043,7 @@ def test_zeros_ones_compilation(self):
x = np.ones(3) + np.zeros(3)
y = np.ones(3) + np.ones(3)

self.assertEqual(count[0], 1)
self.assertEqual(1, count[0])
self.assertAllClose(x, onp.ones(3), check_dtypes=False)
self.assertAllClose(y, onp.ones(3) + onp.ones(3), check_dtypes=False)

Expand Down
26 changes: 21 additions & 5 deletions tests/multi_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,33 @@ def test_computation_follows_data(self):
"primitive arguments must be colocated on the same device"):
jit_add(jax.device_put(x, devices[2]), jax.device_put(x, devices[3]))

# A jitted computation with a device specification behaves as if the
# Even jit of trivial computations leaves the result uncommitted
x_uncommitted = np.array([1, 2, 3])
y = jax.jit(lambda x: x)(x_uncommitted)
self.assert_uncommitted_to_device(y, devices[0])

z1, z2 = jax.jit(lambda x: (x, x))(x_uncommitted)
self.assert_uncommitted_to_device(z1, devices[0])
self.assert_uncommitted_to_device(z2, devices[0])
self.assertIs(z1, z2)

x2_uncommitted = np.array([2, 3])
z1, z2, z3 = jax.jit(lambda x, y: (y, 1, x))(x_uncommitted, x2_uncommitted)
self.assert_uncommitted_to_device(z1, devices[0])
self.assertIs(z2, 1)
self.assert_uncommitted_to_device(z3, devices[0])


# A jitted computation with an device specification behaves as if the
# arguments are first device_put to the specified device. The result
# will be committed on the specified.
# The `device` parameter is experimental, and subject to change.
jit_add_on4 = jax.jit(lambda a, b: a + b, device=devices[4])
self.assert_committed_to_device(jit_add_on4(1, 2), devices[4])
self.assert_committed_to_device(jit_add_on4(1, jax.device_put(2, devices[2])),
devices[4])
self.assert_committed_to_device(jit_add_on4(jax.device_put(x, devices[2]),
jax.device_put(x, devices[3])),
self.assert_committed_to_device(jit_add_on4(jax.device_put(x_uncommitted, devices[2]),
jax.device_put(x_uncommitted, devices[3])),
devices[4])

def test_primitive_compilation_cache(self):
Expand Down Expand Up @@ -146,8 +163,7 @@ def test_device_put(self):

# test device_put on lazy values
x = jax.device_put(np.zeros(2), device=devices[0])
# TODO(necula): re-enable this check
# self.assert_committed_to_device(x, devices[0])
self.assert_committed_to_device(x, devices[0])

y = jax.device_put(x, device=devices[1])
self.assert_committed_to_device(y, devices[1])
Expand Down

0 comments on commit 687da66

Please sign in to comment.