From 01d3bd5b593eddf03742dc74dd72178010266a8d Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Sun, 21 Mar 2021 15:53:24 -0700 Subject: [PATCH] fix convert_element_type on large inputs --- jax/_src/lax/lax.py | 8 +++++++- jax/experimental/host_callback.py | 2 ++ tests/host_callback_test.py | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 4a2887045fb1..afb50ce7a736 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2658,6 +2658,12 @@ def _minmax_translation_rule(c, x, y, *, minmax=None, cmp=None): ad.defjvp_zero(lt_p) +def _convert_element_type_impl(operand, *, new_dtype, weak_type): + if not isinstance(operand, xla.DeviceArray): + operand = np.asarray(operand, dtype=new_dtype) + return xla.apply_primitive(convert_element_type_p, operand, + new_dtype=new_dtype, weak_type=weak_type) + def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type): return operand.shape @@ -2693,7 +2699,7 @@ def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type): return convert_element_type_p.bind(tangent, new_dtype=new_dtype, weak_type=weak_type) convert_element_type_p = core.convert_element_type_p -convert_element_type_p.def_impl(partial(xla.apply_primitive, convert_element_type_p)) +convert_element_type_p.def_impl(_convert_element_type_impl) convert_element_type_p.def_abstract_eval( partial(standard_abstract_eval, convert_element_type_p, _convert_element_type_shape_rule, _convert_element_type_dtype_rule, diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index a84a23b0407f..961ad1cadd1b 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -934,6 +934,7 @@ def _outside_call_jvp_rule(primals, tangents, **params): if not params["identity"]: raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.") tangent_instantiated = tuple(map(_instantiate_zeros, primals, tangents)) + tangent_instantiated = tuple(map(ad.replace_float0s, primals, tangent_instantiated)) arg_treedef = params["arg_treedef"] # The argument to the jvp tap is a pair of the tapped primals and tangents @@ -946,6 +947,7 @@ def _outside_call_jvp_rule(primals, tangents, **params): arg_treedef=jvp_arg_treedef, )) out_primals_tapped, out_tangents_tapped = util.split_list(out_all, [len(primals)]) + out_tangents_tapped = map(ad.recast_to_float0, out_primals_tapped, out_tangents_tapped) return tuple(out_primals_tapped), tuple(out_tangents_tapped) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index a0d74caca76f..932de7e0b3ad 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -1028,7 +1028,7 @@ def func(x, yint): 2 ) transforms: ['jvp', 'transpose'] what: pair ( 2.00 - False )""", testing_stream.output) + 0 )""", testing_stream.output) testing_stream.reset() def test_tap_vmap(self):