Skip to content

Commit

Permalink
temporarily switch off jax-ml#2414 changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj authored and srvasude committed May 5, 2020
1 parent cd986eb commit 2d28359
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions jax/lax/lax_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def fori_loop(lower, upper, body_fun, init_val):
except TypeError:
use_scan = False
else:
use_scan = True
use_scan = False # TODO(mattjj): re-enable this

if use_scan:
(_, _, result), _ = scan(_fori_scan_body_fun(body_fun),
Expand Down Expand Up @@ -1209,7 +1209,8 @@ def scan_bind(*args, forward, length, num_consts, num_carry, jaxpr, linear):
scan_p = core.Primitive("scan")
scan_p.multiple_results = True
scan_p.def_custom_bind(scan_bind)
scan_p.def_impl(partial(xla.apply_primitive, scan_p))
scan_p.def_impl(_scan_impl)
# scan_p.def_impl(partial(xla.apply_primitive, scan_p)) # TODO(mattjj): re-enable
scan_p.def_abstract_eval(_scan_abstract_eval)
ad.primitive_jvps[scan_p] = _scan_jvp
ad.primitive_transposes[scan_p] = _scan_transpose
Expand Down

0 comments on commit 2d28359

Please sign in to comment.