While-loop vmap bug #3204

Closed
adabbott opened this issue May 25, 2020 · 3 comments · Fixed by #3207
Assignees: gnecula
Labels: bug (Something isn't working)

Comments

@adabbott

The following code block runs as expected on jax 0.1.63, jaxlib 0.1.45, but fails on all later versions, including master:

import jax
import jax.numpy as np
from jax.experimental import loops

def test(a,b):
    with loops.Scope() as s:
        s.val = 0
        s.i = 0
        s.j = 0
        for _ in s.while_range(lambda: s.i < a + 1):
            s.j = 0
            for _ in s.while_range(lambda: s.j < b + 1):
                s.val += s.i + s.j
                s.j += 1
            s.i += 1
        return s.val

# vectorized version
vmap_test = jax.vmap(test, (0,0))
arr = np.arange(5)
vmap_test(arr, arr)
Traceback:

Traceback (most recent call last):
  File "test.py", line 21, in <module>
    print(vmap_test(arr, arr))
  File "/home/adabbott/Git/jax/jax/jax/api.py", line 858, in batched_fun
    lambda: flatten_axes(out_tree(), out_axes))
  File "/home/adabbott/Git/jax/jax/jax/interpreters/batching.py", line 34, in batch
    return batched_fun.call_wrapped(*in_vals)
  File "/home/adabbott/Git/jax/jax/jax/linear_util.py", line 150, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "test.py", line 12, in test
    for _ in s.while_range(lambda: s.j < b + 1):
  File "/home/adabbott/Git/jax/jax/jax/experimental/loops.py", line 341, in __next__
    self.end_tracing_body()
  File "/home/adabbott/Git/jax/jax/jax/experimental/loops.py", line 407, in end_tracing_body
    carried_init_vals, body_typed_jaxpr, body_const_vals)
  File "/home/adabbott/Git/jax/jax/jax/experimental/loops.py", line 576, in build_output_vals
    body_jaxpr=body_typed_jaxpr)
  File "/home/adabbott/Git/jax/jax/jax/core.py", line 212, in bind
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
  File "/home/adabbott/Git/jax/jax/jax/interpreters/partial_eval.py", line 141, in process_primitive
    return custom_partial_eval_rules[primitive](self, *tracers, **params)
  File "/home/adabbott/Git/jax/jax/jax/lax/lax_control_flow.py", line 517, in _while_partial_eval
    body_jaxpr=body_jaxpr_known)
  File "/home/adabbott/Git/jax/jax/jax/core.py", line 212, in bind
    out_tracer = top_trace.process_primitive(self, tracers, kwargs)
  File "/home/adabbott/Git/jax/jax/jax/interpreters/batching.py", line 134, in process_primitive
    val_out, dim_out = batched_primitive(vals_in, dims_in, **params)
  File "/home/adabbott/Git/jax/jax/jax/lax/lax_control_flow.py", line 391, in _while_loop_batching_rule
    body_nconsts=body_nconsts, body_jaxpr=body_jaxpr_batched)
  File "/home/adabbott/Git/jax/jax/jax/core.py", line 209, in bind
    return self.impl(*args, **kwargs)
  File "/home/adabbott/Git/jax/jax/jax/interpreters/xla.py", line 217, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *map(arg_spec, args), **params)
  File "/home/adabbott/Git/jax/jax/jax/interpreters/xla.py", line 248, in xla_primitive_callable
    *avals, **params)
  File "/home/adabbott/Git/jax/jax/jax/interpreters/xla.py", line 295, in primitive_computation
    *xla_args, **params)
  File "/home/adabbott/Git/jax/jax/jax/lax/lax_control_flow.py", line 332, in _while_loop_translation_rule
    new_z = _map(partial(_pred_bcast_select, body_c, body_pred), new_z, z)
  File "/home/adabbott/Git/jax/jax/jax/util.py", line 34, in safe_map
    return list(map(f, *args))
  File "/home/adabbott/Git/jax/jax/jax/lax/lax_control_flow.py", line 350, in _pred_bcast_select
    assert pred_shape == x_shape[:len(pred_shape)] == y_shape[:len(pred_shape)]
AssertionError

The failure appears to occur only when the inner while-loop's bound b is vectorized:

# this works
vmap_test = jax.vmap(test, (0,None))
vmap_test(arr, 3)

# this fails
vmap_test = jax.vmap(test, (None,0))
vmap_test(3, arr)
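
For anyone stuck on an affected version, one possible workaround (not part of the original report, just a sketch) is to avoid a batched while-loop bound altogether by using jax.lax.fori_loop with a static trip count and masking the updates. MAX_ITERS here is a hypothetical static bound that must cover the largest a + 1 and b + 1 you expect:

import jax
import jax.numpy as jnp

MAX_ITERS = 16  # hypothetical static bound; must exceed the largest a + 1 and b + 1

def test_masked(a, b):
    def outer(i, val):
        def inner(j, v):
            # accumulate only while both loop indices are inside their dynamic bounds
            active = (i < a + 1) & (j < b + 1)
            return v + jnp.where(active, i + j, 0)
        return jax.lax.fori_loop(0, MAX_ITERS, inner, val)
    return jax.lax.fori_loop(0, MAX_ITERS, outer, 0)

arr = jnp.arange(5)
print(jax.vmap(test_masked, (0, 0))(arr, arr))  # [0, 4, 18, 48, 100], i.e. n*(n+1)**2 for n = 0..4

With a static trip count the loop predicate no longer depends on the batched arguments, so this should sidestep the predicate-broadcast code path that the assertion above is failing in.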
@adabbott (Author)

I should also note that the same behavior occurs when using jax.lax.while_loop directly, without the convenience of the loops module:

Pure jax.lax.while_loop version:

import jax
import jax.numpy as np

def test(a,b):
    val = 0
    i = 0
    j = 0

    condfun_1 = lambda inp: inp[1] < a + 1 
    condfun_2 = lambda inp: inp[2] < b + 1 

    def bodyfun_1(inp):
        val, i, j = inp
        j = 0
        def bodyfun_2(inp):
            val, i, j = inp
            val += i + j
            j += 1
            return (val, i, j)
        result = jax.lax.while_loop(condfun_2, bodyfun_2, (val,i,j))
        val = result[0]
        i += 1
        return (val, i, j)
    result = jax.lax.while_loop(condfun_1, bodyfun_1, (val,i,j))
    return result[0]

arr = np.arange(5)
vmap_test = jax.vmap(test, (0,0))
vmap_test(arr, arr)

@gnecula gnecula self-assigned this May 26, 2020
@gnecula gnecula added the bug Something isn't working label May 26, 2020
@gnecula (Collaborator) commented May 26, 2020

This is triggered in the context of the partial evaluation of while, which was added in #2497. I feel that adding partial eval for while was a mistake: its only purpose was to improve an error message, but it now adds real code generation, which in this case fails.

The failure, however, is a bug in the code generation. Fixed in #3207. Please re-open if you think this does not fix the problem.
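
For context (this is not JAX's actual implementation, only a rough sketch of the semantics the batching rule has to produce): when a while_loop is vmapped with a batched predicate, the loop must keep running until every lane's predicate is false, and lanes that have already finished must keep their previous carry; that per-lane select against a broadcast predicate is what the _pred_bcast_select assertion in the traceback is guarding. A minimal illustration, assuming a single-array carry with the batch along the leading axis:

import jax
import jax.numpy as jnp

def batched_while(cond, body, init):
    # loop while any lane is still active
    def new_cond(carry):
        return jnp.any(cond(carry))
    # advance active lanes; finished lanes keep their previous value
    def new_body(carry):
        active = cond(carry)
        return jnp.where(active, body(carry), carry)
    return jax.lax.while_loop(new_cond, new_body, init)

# each lane counts up to its own limit
limits = jnp.arange(5)
print(batched_while(lambda c: c < limits, lambda c: c + 1, jnp.zeros(5, jnp.int32)))  # [0 1 2 3 4]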

@gnecula gnecula closed this as completed May 26, 2020
@adabbott (Author)

Thanks for the quick fix!
