
jax.lax.{scan, map} break when supplied with an empty array #2412

Closed
sritchie opened this issue Mar 13, 2020 · 1 comment · Fixed by #2414
Labels
bug Something isn't working

Comments

@sritchie
Contributor

Hey all, with 0.1.59 I'm seeing an error from both map and scan when they're given an empty array:

import jax
import numpy as np

jax.lax.map(lambda x: x * x, np.array([]))
jax.lax.scan(lambda i, j: (i, j), np.array([]), np.array([]))

Both result in this error:

TypeError: slice slice_sizes must be less than or equal to operand shape, got slice_sizes (1,) for operand shape (0,).

This has been my workaround, though there is surely something more elegant and more correct:

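def safe_map(f, xs):
  # Only correct when f preserves each element's shape and dtype:
  # for an empty xs this returns xs itself, not an empty array of f's output shape.
  if xs.shape[0] == 0:
    return xs
  else:
    return jax.lax.map(f, xs)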

Thank you!
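A slightly more general version of that workaround is sketched below; it is a hypothetical helper for JAX versions affected by this issue, not a JAX API. It assumes xs is a single array (not a pytree) and uses jax.eval_shape to learn f's per-element output shape and dtype without running f, so the empty result has the shape jax.lax.map would have produced.

```python
import jax
import jax.numpy as jnp


def safe_map(f, xs):
    """Hypothetical workaround: like jax.lax.map, but handles a length-0 leading axis."""
    if xs.shape[0] == 0:
        # Abstractly evaluate f on a single element (shape xs.shape[1:]) to get
        # its output structure as ShapeDtypeStructs, without executing f.
        out = jax.eval_shape(f, jax.ShapeDtypeStruct(xs.shape[1:], xs.dtype))
        # Build an empty array with a leading axis of size 0 for every output leaf.
        return jax.tree_util.tree_map(
            lambda s: jnp.zeros((0,) + tuple(s.shape), s.dtype), out)
    return jax.lax.map(f, xs)


# An f that changes the element shape: (3,) -> (2, 3).
empty = safe_map(lambda x: jnp.stack([x, x]), jnp.zeros((0, 3)))
print(empty.shape)  # (0, 2, 3)
```

Unlike returning xs directly, this also works when f returns a tuple or other pytree, since every output leaf gets its own empty array.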

mattjj self-assigned this Mar 13, 2020
mattjj added the bug Something isn't working label Mar 13, 2020
mattjj added a commit that referenced this issue Mar 13, 2020
When a fori_loop specialized on trip count is to be evaluated, it's
preferable to generate a scan rather than a while_loop because the
former is reverse-mode differentiable while the latter is not. Otherwise
they're essentially the same; in particular, no extensive inputs/outputs
arise unless reverse-mode autodiff is applied.

Also fixes #2412.
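The commit message's point about fori_loop can be illustrated with a small sketch (repeated_square is a hypothetical example function). When the loop bounds are concrete Python ints, the trip count is known at trace time, so JAX can lower the fori_loop to a scan and jax.grad can differentiate through it in reverse mode. With traced bounds the loop becomes a while_loop, which does not support reverse-mode differentiation.

```python
import jax


def repeated_square(x):
    # Bounds 0 and 3 are Python ints, so the trip count is static and this
    # fori_loop can be lowered to a scan, which supports reverse-mode autodiff.
    # Squaring three times computes x ** 8.
    return jax.lax.fori_loop(0, 3, lambda i, acc: acc * acc, x)


# d/dx x**8 = 8 * x**7, which is 8.0 at x = 1.0.
print(jax.grad(repeated_square)(1.0))
```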
@mattjj
Collaborator

mattjj commented Mar 13, 2020

Thanks for catching this! #2414 will squash it.
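A quick check of the fixed behavior, assuming a JAX release that includes #2414 (any current version should): with a zero-length leading axis, lax.map returns an empty stacked output, and lax.scan returns the initial carry unchanged plus empty stacked outputs, instead of raising the slice_sizes TypeError. This sketch uses jax.numpy arrays; plain NumPy inputs behave the same way.

```python
import jax
import jax.numpy as jnp

# Input with a leading axis of length 0, as in the original report.
xs = jnp.array([])

# lax.map over zero elements: an empty array with f's per-element output shape stacked.
squares = jax.lax.map(lambda x: x * x, xs)

# lax.scan over zero elements: f is only traced for shapes, never applied,
# so the carry comes back unchanged and the stacked outputs are empty.
carry, ys = jax.lax.scan(lambda c, x: (c, x), jnp.array([]), xs)

print(squares.shape, carry.shape, ys.shape)  # (0,) (0,) (0,)
```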

srvasude pushed a commit to srvasude/jax that referenced this issue May 5, 2020

2 participants