Fix large constants in activations #1640
Conversation
Thanks for this! Looks like there were some test failures in 64-bit mode.
Turns out lax.select fails because in X64 mode a Python float * float32 tensor yields a float64 tensor. Isn't that a violation of NumPy dtype promotion rules?
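A minimal repro sketch of the promotion question above (not from the PR). Note that NumPy's value-based casting keeps a Python float times a float32 array at float32, which is why the promotion looks surprising; also, in current JAX Python scalars are weakly typed, so the result may stay float32 and the behavior described in this thread reflects JAX at the time:

```python
# Sketch: check what dtype a Python float * float32 array produces under X64 mode.
from jax import config
config.update("jax_enable_x64", True)  # enable X64 mode before any computation

import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float32)
y = 0.5 * x
print(y.dtype)  # the comment above reports float64 under X64 mode; newer JAX may keep float32
```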
I'm hesitant to merge this just because it makes the code uglier and we hope #1668 (or a related PR) will fix the underlying issue. But at the same time this is fixing a real problem! @jekbradbury WDYT? Should we merge, maybe with TODO notes to revert it once we fix the underlying issues?
I think I'd rather merge a version that uses […]. The version in this PR depends on the specific current behavior where lax.full is lazy but broadcasts aren't, which makes it a little opaque.
Before this change, inner jitted functions wouldn't necessarily be fully staged out into an outer-jit trace; instead, as much as possible would be hoisted out of the inner jit. That led to extra constants getting materialized in #1640. For example:

```python
@jit
def f(x, y):
    z = 2 * x
    return y + z

@jit
def g(x):
    return f(2, x)

g(3)
```

would lead to these XLA computations being compiled and executed:

```
HloModule jit_f.7

ENTRY jit_f.7 {
  parameter.2 = () parameter(1)
  tuple.3 = () tuple()
  parameter.1 = s32[] parameter(0)
  constant.4 = s32[] constant(2)
  multiply.5 = s32[] multiply(parameter.1, constant.4)
  ROOT tuple.6 = ((), s32[]) tuple(tuple.3, multiply.5)
}

HloModule jit_g.14

jaxpr_subcomputation.4 {
  parameter.6 = () parameter(1)
  tuple.8 = () tuple()
  parameter.7 = s32[] parameter(2)
  parameter.5 = s32[] parameter(0)
  add.9 = s32[] add(parameter.7, parameter.5)
  ROOT tuple.10 = (s32[]) tuple(add.9)
}

ENTRY jit_g.14 {
  constant.1 = s32[] constant(4)
  tuple.3 = () tuple()
  parameter.2 = s32[] parameter(0)
  call.11 = (s32[]) call(constant.1, tuple.3, parameter.2), to_apply=jaxpr_subcomputation.4
  get-tuple-element.12 = s32[] get-tuple-element(call.11), index=0
  ROOT tuple.13 = (s32[]) tuple(get-tuple-element.12)
}
```

Notice that the `multiply` is separated out from the `add`, and in particular the XLA computation underlying `g` only has the `add` in it. This behavior was desirable when using partial evaluation for reverse-mode autodiff, since in that case we want to partially evaluate all the primal values underneath a call while staging out a jaxpr for the tangent values. But it was undesirable for the other use of partial evaluation, namely forming jaxprs under `jit` (and `pmap`). The solution was just to tag jaxpr traces differently in the two cases.
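One way to see how the inner jit gets staged out is to look at the traced jaxpr. This is a sketch using `jax.make_jaxpr` (a real JAX API) on the same `f`/`g` example as above; the exact printed form depends on the JAX version, and the change described above is what determines whether the inner call appears fully staged:

```python
import jax
from jax import jit

@jit
def f(x, y):
    z = 2 * x
    return y + z

@jit
def g(x):
    return f(2, x)

# Print the jaxpr traced for g. After the change described above, the inner
# call to f should be staged out as a call rather than partially hoisted
# into a constant in g's computation.
print(jax.make_jaxpr(g)(3))
```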
This fixes an issue where using the elu, selu, or tanh activation caused a constant as large as the input to be created.
This PR also switches the use of np.where to lax.select for activation functions to avoid accidental broadcasting of scalars in the future.
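A minimal illustration of the np.where vs. lax.select distinction (not the exact code from this PR; the `elu_*` names are hypothetical): `lax.select` requires the predicate and both branches to have the same shape and dtype, so scalar constants must be cast and combined explicitly instead of being broadcast silently.

```python
import jax.numpy as jnp
from jax import lax

def elu_with_where(x, alpha=1.0):
    # jnp.where happily broadcasts scalar branches, which is how an
    # accidental input-sized constant can sneak into the traced computation.
    return jnp.where(x > 0, x, alpha * jnp.expm1(x))

def elu_with_select(x, alpha=1.0):
    # Cast the scalar to x.dtype first so no dtype promotion occurs (see the
    # X64-mode failure discussed above), then hand lax.select operands that
    # already match x's shape and dtype.
    alpha = jnp.asarray(alpha, dtype=x.dtype)
    return lax.select(x > 0, x, alpha * jnp.expm1(x))
```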
I added a test that verifies that making the jaxpr does not produce constants that scale in size with the input.
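A sketch of the kind of check described above (not the PR's actual test), using `jax.make_jaxpr` and the `jax.nn` activations available in current JAX: trace an activation on a large input and assert that no jaxpr constant scales with the input size.

```python
import jax
import jax.numpy as jnp
from jax import nn

def check_no_input_sized_constants(fn, x):
    # make_jaxpr returns a ClosedJaxpr whose .consts holds the captured constants.
    jaxpr = jax.make_jaxpr(fn)(x)
    for const in jaxpr.consts:
        assert jnp.size(const) < jnp.size(x), (
            f"found a constant with {jnp.size(const)} elements")

x = jnp.ones((1000,), dtype=jnp.float32)
check_no_input_sized_constants(nn.elu, x)
check_no_input_sized_constants(nn.selu, x)
```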