Fix large constants in activations #1640

Closed
jheek wants to merge 3 commits

Conversation

@jheek (Contributor) commented Nov 7, 2019

This fixes an issue where using the elu, selu, or tanh activation creates a constant as large as the input.

This PR also switches the activation functions from np.where to lax.select, to avoid accidental broadcasting of scalars in the future.
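To make the contrast concrete, here is a toy sketch (not the PR's actual elu/selu/tanh code; the function names are made up) of the pattern being changed:

```python
import jax.numpy as jnp
from jax import lax

def clamp_where(x):
    # np.where-style: the scalar 0. is silently broadcast to x's shape, and
    # that broadcast could be staged out as a constant as large as x.
    return jnp.where(x > 0, x, 0.)

def clamp_select(x):
    # lax.select rejects scalar operands, so the zero array has to be built
    # explicitly; lax.full_like goes through lax.full, which (per the
    # discussion later in this thread) was lazy at the time.
    return lax.select(x > 0, x, lax.full_like(x, 0.))
```

The lax.select version trades a little verbosity for the guarantee that no operand is implicitly broadcast.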

I added a test that verifies that tracing to a jaxpr does not produce constants that scale with the size of the input.
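A minimal sketch of that kind of check (not the actual test added in this PR; the helper name is made up) might look like this, tracing with make_jaxpr and inspecting the closed-over constants:

```python
import jax
import jax.numpy as jnp

def assert_no_input_sized_constants(fn, x):
    # Trace fn to a jaxpr and check that no closed-over constant is as
    # large as the input; a materialized broadcast would show up here.
    jaxpr = jax.make_jaxpr(fn)(x)
    for const in jaxpr.consts:
        assert jnp.size(const) < jnp.size(x), (
            f"constant of shape {jnp.shape(const)} was materialized")

assert_no_input_sized_constants(jax.nn.elu, jnp.ones((1024, 1024)))
```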

@mattjj (Collaborator) commented Nov 9, 2019

Thanks for this! It looks like there were some test failures in 64-bit mode.

@jheek (Contributor, Author) commented Nov 12, 2019

It turns out lax.select fails because, in x64 mode, a Python float multiplied by a float32 tensor yields a float64 tensor. Isn't that a violation of NumPy's dtype promotion rules?
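For context, here is a quick way to observe the promotion being described; the result depends on the JAX version, and JAX's later weak-type promotion rules changed the outcome so that a Python scalar no longer upcasts a float32 array:

```python
from jax import config
config.update("jax_enable_x64", True)  # must happen before any arrays are created

import jax.numpy as jnp

x = jnp.ones(3, dtype=jnp.float32)
y = 2.0 * x
# At the time of this thread, this printed float64 under x64 mode; NumPy
# (and current JAX, with weak typing) keeps float32 for a Python scalar.
print(y.dtype)
```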

@mattjj (Collaborator) commented Nov 14, 2019

I'm hesitant to merge this just because it makes the code uglier, and we hope #1668 (or a related PR) will fix the underlying issue. But at the same time, this is fixing a real problem!

@jekbradbury WDYT? Should we merge, maybe with TODO notes to revert it once we fix the underlying issues?

@jekbradbury (Contributor)

I think I’d rather merge a version that uses tie_in to avoid materializing broadcast constants, if we can get that to work (I can look into it), or wait until #1668 if we expect that to land soon.

The version in this PR depends on the specific current behavior where lax.full is lazy but broadcasts aren’t, which makes it a little opaque.
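For readers unfamiliar with the suggestion, here is a hedged sketch of the tie_in idea for selu. lax.tie_in was a real primitive at the time but later became a no-op and was removed, so this targets JAX versions contemporary with the thread; the alpha/scale defaults are just the standard published selu constants, not anything from this PR:

```python
from jax import lax

def selu_tied(x, alpha=1.6732632423543772, scale=1.0507009873554805):
    # Build the full-shape zero lazily and tie it to the traced input so it
    # cannot be hoisted out of the jit and materialized at trace time.
    zeros = lax.tie_in(x, lax.full_like(x, 0.0))
    safe_x = lax.select(x > 0, zeros, x)  # clamp the untaken branch of expm1
    return scale * lax.select(x > 0, x, alpha * lax.expm1(safe_x))
```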

@mattjj mentioned this pull request on Dec 5, 2019
mattjj added a commit that referenced this pull request Dec 12, 2019
Before this change, inner jitted functions wouldn't necessarily be fully
staged out into an outer-jit trace; instead, as much as possible would
be hoisted out of the inner jit. That led to extra constants getting
materialized in #1640.

For example:

```python
from jax import jit

@jit
def f(x, y):
    z = 2 * x
    return y + z

@jit
def g(x):
    return f(2, x)

g(3)
```

would lead to these XLA computations being compiled and executed:

```
HloModule jit_f.7

ENTRY jit_f.7 {
  parameter.2 = () parameter(1)
  tuple.3 = () tuple()
  parameter.1 = s32[] parameter(0)
  constant.4 = s32[] constant(2)
  multiply.5 = s32[] multiply(parameter.1, constant.4)
  ROOT tuple.6 = ((), s32[]) tuple(tuple.3, multiply.5)
}

HloModule jit_g.14

jaxpr_subcomputation.4 {
  parameter.6 = () parameter(1)
  tuple.8 = () tuple()
  parameter.7 = s32[] parameter(2)
  parameter.5 = s32[] parameter(0)
  add.9 = s32[] add(parameter.7, parameter.5)
  ROOT tuple.10 = (s32[]) tuple(add.9)
}

ENTRY jit_g.14 {
  constant.1 = s32[] constant(4)
  tuple.3 = () tuple()
  parameter.2 = s32[] parameter(0)
  call.11 = (s32[]) call(constant.1, tuple.3, parameter.2), to_apply=jaxpr_subcomputation.4
  get-tuple-element.12 = s32[] get-tuple-element(call.11), index=0
  ROOT tuple.13 = (s32[]) tuple(get-tuple-element.12)
}
```

Notice that the `multiply` is separated out from the `add`, and in
particular the XLA computation underlying `g` only has the `add` in it.

This behavior was desirable when using partial evaluation for
reverse-mode autodiff, since in that case we want to partially evaluate
all the primal values underneath a call while staging out a jaxpr for
the tangent values. But it was undesirable for the other use of partial
evaluation, namely forming jaxprs under `jit` (and `pmap`).

The solution was just to tag jaxpr traces differently in the two cases.
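As an aside for anyone reproducing this today, one way to confirm that the multiply now stays inside g's staged computation is to lower the jitted function and inspect the emitted text; the lower/as_text API shown here is current JAX, not what existed when this commit landed:

```python
import jax

@jax.jit
def f(x, y):
    z = 2 * x
    return y + z

@jax.jit
def g(x):
    return f(2, x)

# Both the multiply and the add should now appear in g's lowered computation.
print(g.lower(3).as_text())
```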
mattjj added a commit that referenced this pull request Dec 12, 2019
mattjj added a commit that referenced this pull request Dec 31, 2019
mattjj added a commit that referenced this pull request Dec 31, 2019
@mattjj (Collaborator) commented Dec 31, 2019

I believe the substantive issue here was actually fixed in #1848 (without changing any jax.nn code), but I only added these tests in #1930. Because the tests rely on make_jaxpr reflecting jit's behavior, I had to adapt make_jaxpr slightly as well.

@mattjj closed this in #1930 on Dec 31, 2019
mattjj added a commit that referenced this pull request Dec 31, 2019
@jheek deleted the fix_large_constant_in_activations branch on January 30, 2020