@@ -32,28 +32,42 @@ def _get_aval(x):
32
32
return jax_core .raise_to_shaped (jax_core .get_aval (x ))
33
33
34
34
35
- def fuse (f , * , physicalize : bool = False ):
35
+ def fuse (f = None , * , physicalize : bool = False , debug : bool = False ):
36
36
"""Fuses a function into a single fusable.
37
37
38
+ Args:
39
+ f: The function to fuse.
40
+ physicalize: (experimental) whether to physicalize the function.
41
+ debug: Whether to print debug information.
42
+
38
43
There should be a single call to a `fusable` inside the body of `f`. `fuse`
39
44
returns a transformed function that will fuse the surrounding computation into
40
45
the fusable and invoke it.
41
46
"""
42
- def wrapper (* args , ** kwargs ):
43
- flat_args , in_tree = tree_util .tree_flatten ((args , kwargs ))
44
- debug_info = api_util .debug_info ('fuse' , f , args , kwargs )
45
- flat_fun , out_tree_thunk = api_util .flatten_fun (
46
- lu .wrap_init (f , debug_info = debug_info ), in_tree
47
- )
48
- flat_avals = [_get_aval (x ) for x in flat_args ]
49
- jaxpr , _ , consts , _ = pe .trace_to_jaxpr_dynamic (flat_fun , flat_avals )
50
- out_tree = out_tree_thunk ()
51
- out_flat = fuse_jaxpr (jaxpr , out_tree , consts , * flat_args )
52
- return tree_util .tree_unflatten (out_tree , out_flat )
53
47
54
- if physicalize :
55
- wrapper = fusable_dtype .physicalize (wrapper )
56
- return wrapper
48
+ def decorator (f ):
49
+ def wrapper (* args , ** kwargs ):
50
+ flat_args , in_tree = tree_util .tree_flatten ((args , kwargs ))
51
+ debug_info = api_util .debug_info ("fuse" , f , args , kwargs )
52
+ flat_fun , out_tree_thunk = api_util .flatten_fun (
53
+ lu .wrap_init (f , debug_info = debug_info ), in_tree
54
+ )
55
+ flat_avals = [_get_aval (x ) for x in flat_args ]
56
+ jaxpr , _ , consts , _ = pe .trace_to_jaxpr_dynamic (flat_fun , flat_avals )
57
+ if debug :
58
+ print ("Jaxpr before fusion:" )
59
+ print (jaxpr )
60
+ out_tree = out_tree_thunk ()
61
+ out_flat = fuse_jaxpr (jaxpr , out_tree , consts , * flat_args )
62
+ return tree_util .tree_unflatten (out_tree , out_flat )
63
+
64
+ if physicalize :
65
+ wrapper = fusable_dtype .physicalize (wrapper )
66
+ return wrapper
67
+
68
+ if f is not None :
69
+ return decorator (f )
70
+ return decorator
57
71
58
72
59
73
_fusable : dict [jax_core .Primitive , Any ] = {}
0 commit comments