[jax2tf] Add check that native lowering should not include custom calls not guaranteed to be stable.

PiperOrigin-RevId: 515245302
gnecula authored and jax authors committed Mar 9, 2023
1 parent 50d8358 commit 5c91453
Showing 3 changed files with 94 additions and 12 deletions.
42 changes: 34 additions & 8 deletions jax/experimental/jax2tf/jax2tf.py
@@ -100,6 +100,10 @@
map = util.safe_map
zip = util.safe_zip

# These are the JAX custom call target names that are guaranteed to be stable.
_CUSTOM_CALL_TARGETS_GUARANTEED_STABLE = [
"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape"
]

def _sanitize_scope_name(name):
scope_name = _INVALID_SCOPE_CHAR.sub("_", name)
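The list above is the allowlist that the new check (check_module, further down in this diff) compares custom call targets against. As a rough user-side illustration, not part of this commit, one can scrape a function's lowered module text to see which custom call targets it would bring along; the use of jax.jit(...).lower(...).as_text() and the text-matching regexes are assumptions about the JAX AOT API and the textual IR format, not something this change adds.

import re
import jax
import jax.numpy as jnp

# Mirrors the allowlist introduced above.
STABLE_TARGETS = {"Sharding", "SPMDFullToShardShape", "SPMDShardToFullShape"}

def custom_call_targets(fun, *args):
  # Heuristic scrape of the lowered module text; the exact printing of
  # custom_call ops can differ across JAX/StableHLO versions.
  text = jax.jit(fun).lower(*args).as_text()
  targets = set(re.findall(r'call_target_name\s*=\s*"([^"]+)"', text))
  targets |= set(re.findall(r'custom_call\s+@([\w.\-]+)', text))
  return targets

# On CPU, QR decomposition typically lowers to lapack_*geqrf custom calls,
# which are not on the allowlist.
unstable = custom_call_targets(jnp.linalg.qr, jnp.ones((3, 3))) - STABLE_TARGETS
print("Custom call targets without stability guarantees:", unstable)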
@@ -219,7 +223,7 @@ def convert(fun_jax: Callable,
enable_xla=True,
experimental_native_lowering="default",
experimental_native_lowering_platforms=(),
-            experimental_native_lowering_strict_checks=False) -> Callable:
+            experimental_native_lowering_strict_checks=True) -> Callable:
"""Lowers `fun_jax` into a function that uses only TensorFlow ops.
See
@@ -287,7 +291,8 @@ def convert(fun_jax: Callable,
experimental_native_lowering_strict_checks: DO NOT USE, for experimental purposes only.
In conjunction with `experimental_native_lowering`, enable the following
checks: the lowered computation is executed on a platform for which it
-      was lowered, (more to come).
+      was lowered, the serialized computation contains only custom calls with
+      targets that are guaranteed to be stable, (more to come).
Returns:
A version of `fun_jax` that expects TfVals as arguments (or
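A hedged usage sketch of the flag documented above; the parameter names come from this diff, and the behavior described is only what the docstring promises. With native lowering, the strict checks now default to True, and passing False opts out of them.

import jax.numpy as jnp
from jax.experimental import jax2tf

# Strict checks are on by default under native lowering.
f_tf_strict = jax2tf.convert(jnp.sin, experimental_native_lowering=True)

# Opt out only if you accept custom calls without compatibility guarantees.
f_tf_lenient = jax2tf.convert(
    jnp.sin,
    experimental_native_lowering=True,
    experimental_native_lowering_strict_checks=False)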
@@ -823,9 +828,10 @@ def _out_type(jax_type):
log_msg,
mlir_module_text)

-        if not allow_non_replicated_sharding:
-          check_module(mlir_module,
-                       allow_non_replicated_sharding=allow_non_replicated_sharding)
+        # Check the module after we logged it.
+        check_module(mlir_module,
+                     allow_non_replicated_sharding=allow_non_replicated_sharding,
+                     allow_all_custom_calls=not lowering_params.experimental_native_lowering_strict_checks)

res = tfxla.call_module(args_tf, **call_module_attrs)
if "out_shardings" in lowered.compile_args:
@@ -847,16 +853,23 @@ def _convert_res(res_val, res_jax_type):


def check_module(mod: mlir.ir.Module, *,
-                 allow_non_replicated_sharding: bool):
+                 allow_non_replicated_sharding: bool,
+                 allow_all_custom_calls: bool):
"""Run a number of checks on the module.
TODO: check for custom calls.
Args:
allow_non_replicated_sharding: whether the module is allowed to contain
non_replicated sharding annotations.
allow_all_custom_calls: whether we should allow all custom calls, or
only those that we have explicitly marked as stable.
"""
sharding_attr = mlir.ir.StringAttr.get("Sharding", mod.context)
allowed_custom_call_targets_attrs = [
mlir.ir.StringAttr.get(target, mod.context)
for target in _CUSTOM_CALL_TARGETS_GUARANTEED_STABLE]
disallowed_custom_call_ops: List[str] = []
def check_sharding(op_str: str, loc: mlir.ir.Location):
# Check the shardings in an operation or attribute (`op_str`)
if not allow_non_replicated_sharding:
@@ -874,8 +887,12 @@ def check_op(op: mlir.ir.Operation):
check_sharding(str(a), op.location)

elif op_name == "stablehlo.custom_call":
-      if op.operation.attributes["call_target_name"] == sharding_attr:
-        check_sharding(str(op), op.location )
+      call_target_name_attr = op.operation.attributes["call_target_name"]
+      if (not allow_all_custom_calls and
+          call_target_name_attr not in allowed_custom_call_targets_attrs):
+        disallowed_custom_call_ops.append(str(op))
+      if call_target_name_attr == sharding_attr:
+        check_sharding(str(op), op.location)

def walk_operations(op):
check_op(op)
@@ -885,6 +902,15 @@ def walk_operations(op):
walk_operations(op)

walk_operations(mod)
if disallowed_custom_call_ops:
disallowed_custom_call_ops_str = "\n".join(disallowed_custom_call_ops)
msg = ("Cannot serialize code with custom calls whose targets have no "
"compatibility guarantees. Examples are:\n"
f"{disallowed_custom_call_ops_str}.\n"
"If you know what you are doing you can disable this check by "
"setting `experimental_native_lowering_strict_checks` to "
"`False`.")
raise ValueError(msg)


def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
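For illustration, a hedged sketch of how the new error surfaces. Based on the call site earlier in this diff, check_module runs right before tfxla.call_module, so the ValueError appears when the converted function is traced or executed, not when convert() returns. Using cholesky as the trigger is an assumption drawn from the CPU skip list added later in this commit (lapack_*potrf is not on the allowlist), so this only fails as shown on CPU.

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def chol(x):
  return jnp.linalg.cholesky(x)

f_tf = jax2tf.convert(chol, experimental_native_lowering=True)
try:
  # Tracing the converted function lowers the JAX computation and runs the
  # module checks, which is where the new error is raised.
  tf.function(f_tf, autograph=False)(tf.eye(3))
except ValueError as e:
  assert "Cannot serialize code with custom calls" in str(e)
  # Either avoid the offending op, or opt out with
  # experimental_native_lowering_strict_checks=False.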
11 changes: 7 additions & 4 deletions jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -1083,14 +1083,17 @@ def test_device_array_arg(self):
self.ConvertAndCompare(jnp.sin, jnp.zeros((2, 3), jnp.float32))

def test_randint(self):
-    if jtu.device_under_test() == "gpu" and config.jax2tf_default_experimental_native_lowering:
-      raise unittest.SkipTest("randint on GPU uses custom calls; not supported")

def randint():
return jax.random.randint(
jax.random.PRNGKey(42), shape=(), minval=0, maxval=1)

-    self.ConvertAndCompare(randint)
+    with contextlib.ExitStack() as stack:
+      if (jtu.device_under_test() == "gpu" and
+          config.jax2tf_default_experimental_native_lowering):
+        stack.enter_context(
+            self.assertRaisesRegex(ValueError,
+                                   "Cannot serialize code with custom calls whose targets .*"))
+      self.ConvertAndCompare(randint)

def test_op_metadata_simple(self):
self.skipTest("include_xla_op_metadata not yet enabled")
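The test above only expects the new error on GPU under native lowering, and uses contextlib.ExitStack to enter assertRaisesRegex conditionally. A distilled, illustrative version of that idiom follows; the class and method names are hypothetical, not part of the commit.

import contextlib
import unittest

class ConditionalFailureTest(unittest.TestCase):

  def run_maybe_failing(self, should_fail: bool, fn):
    # Enter the expected-exception context only when the configuration is
    # known to fail; otherwise just run fn normally.
    with contextlib.ExitStack() as stack:
      if should_fail:
        stack.enter_context(
            self.assertRaisesRegex(ValueError, "Cannot serialize code"))
      fn()

  def test_example(self):
    def boom():
      raise ValueError("Cannot serialize code with custom calls ...")
    self.run_maybe_failing(should_fail=True, fn=boom)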
53 changes: 53 additions & 0 deletions jax/experimental/jax2tf/tests/primitives_test.py
@@ -132,6 +132,59 @@ def test_prim(self, harness: primitive_harness.Harness):
"lu" in harness.fullname):
raise unittest.SkipTest("b/269388847: lu failures on GPU")

def skipCustomCallTest(target: str):
raise unittest.SkipTest(
f"TODO(b/272239584): custom call target not guaranteed stable: {target}")
if config.jax2tf_default_experimental_native_lowering:
if device == "cpu":
if "cholesky_shape" in harness.fullname:
skipCustomCallTest("lapack_spotrf, lapack_dpotrf, lapack_zpotrf, lapack_cpotrf")
if "eig_shape" in harness.fullname:
skipCustomCallTest("lapack_cgeev, lapack_sgeev, lapack_dgeev, lapack_zgeev")
if "eigh_shape" in harness.fullname:
skipCustomCallTest("lapack_cheevd, lapack_ssyevd, lapack_zheevd")
if "lu_shape" in harness.fullname:
skipCustomCallTest("lapack_zgetrf, lapack_sgetrf")
if "svd_shape" in harness.fullname:
skipCustomCallTest("lapack_sgesdd, lapack_zgesdd, lapack_cgesdd")
if "qr_" in harness.fullname:
skipCustomCallTest("lapack_dgeqrf, lapack_cgeqrf, lapack_zgeqrf")
if "triangular_solve_" in harness.fullname:
skipCustomCallTest("blas_ctrsm, blas_dtrsm, blas_ztrsm, blas_strsm")
if "fft_" in harness.fullname:
skipCustomCallTest("ducc_fft")
if "custom_linear_solve" in harness.fullname:
skipCustomCallTest("lapack_sgetrf, lapack_dgetrf")

elif device == "tpu":
if "qr_" in harness.fullname:
skipCustomCallTest("Qr")
if "svd_shape" in harness.fullname:
skipCustomCallTest("Qr")
if "lu_shape" in harness.fullname:
skipCustomCallTest("LuDecomposition")
if "custom_linear_solve_" in harness.fullname:
skipCustomCallTest("LuDecomposition")
if "eigh_shape" in harness.fullname:
skipCustomCallTest("Eigh")

elif device == "gpu":
if "eigh_shape" in harness.fullname:
skipCustomCallTest("cusolver_syevj")
if ("qr_" in harness.fullname or
"custom_linear_solve_" in harness.fullname):
skipCustomCallTest("cusolver_geqrf, cublas_geqrf_batched")
if "svd_shape" in harness.fullname:
skipCustomCallTest("cusolver_gesvdj")
if ("random_split_" in harness.fullname or
"random_gamma_" in harness.fullname or
"random_uniform_" in harness.fullname or
"random_categorical_" in harness.fullname or
"random_randint" in harness.fullname):
skipCustomCallTest("cu_threefry2x32")
if "tridiagonal_solve_shape" in harness.fullname:
skipCustomCallTest("cusparse_gtsv2_f32, cusparse_gtsv2_f64")

associative_scan_reductions = harness.params.get("associative_scan_reductions", False)
try:
with jax.jax2tf_associative_scan_reductions(associative_scan_reductions):
