Skip to content

Commit 87272fb

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
[Pallas/Fuser] Add debug option to fuser.fuse that prints out jaxpr
PiperOrigin-RevId: 735505460
1 parent affe2e7 commit 87272fb

File tree

1 file changed

+29
-15
lines changed

1 file changed

+29
-15
lines changed

jax/_src/pallas/fuser/jaxpr_fusion.py

+29-15
Original file line numberDiff line numberDiff line change
@@ -32,28 +32,42 @@ def _get_aval(x):
3232
return jax_core.raise_to_shaped(jax_core.get_aval(x))
3333

3434

35-
def fuse(f, *, physicalize: bool = False):
35+
def fuse(f=None, *, physicalize: bool = False, debug: bool = False):
3636
"""Fuses a function into a single fusable.
3737
38+
Args:
39+
f: The function to fuse.
40+
physicalize: (experimental) whether to physicalize the function.
41+
debug: Whether to print debug information.
42+
3843
There should be a single call to a `fusable` inside the body of `f`. `fuse`
3944
returns a transformed function that will fuse the surrounding computation into
4045
the fusable and invoke it.
4146
"""
42-
def wrapper(*args, **kwargs):
43-
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
44-
debug_info = api_util.debug_info('fuse', f, args, kwargs)
45-
flat_fun, out_tree_thunk = api_util.flatten_fun(
46-
lu.wrap_init(f, debug_info=debug_info), in_tree
47-
)
48-
flat_avals = [_get_aval(x) for x in flat_args]
49-
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
50-
out_tree = out_tree_thunk()
51-
out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args)
52-
return tree_util.tree_unflatten(out_tree, out_flat)
5347

54-
if physicalize:
55-
wrapper = fusable_dtype.physicalize(wrapper)
56-
return wrapper
48+
def decorator(f):
49+
def wrapper(*args, **kwargs):
50+
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
51+
debug_info = api_util.debug_info("fuse", f, args, kwargs)
52+
flat_fun, out_tree_thunk = api_util.flatten_fun(
53+
lu.wrap_init(f, debug_info=debug_info), in_tree
54+
)
55+
flat_avals = [_get_aval(x) for x in flat_args]
56+
jaxpr, _, consts, _ = pe.trace_to_jaxpr_dynamic(flat_fun, flat_avals)
57+
if debug:
58+
print("Jaxpr before fusion:")
59+
print(jaxpr)
60+
out_tree = out_tree_thunk()
61+
out_flat = fuse_jaxpr(jaxpr, out_tree, consts, *flat_args)
62+
return tree_util.tree_unflatten(out_tree, out_flat)
63+
64+
if physicalize:
65+
wrapper = fusable_dtype.physicalize(wrapper)
66+
return wrapper
67+
68+
if f is not None:
69+
return decorator(f)
70+
return decorator
5771

5872

5973
_fusable: dict[jax_core.Primitive, Any] = {}

0 commit comments

Comments
 (0)