Type specification/inference now fails in @jax.jit #9380

Closed
sjperkins opened this issue Jan 31, 2022 · 5 comments · Fixed by #9391

@sjperkins

sjperkins commented Jan 31, 2022

The following test case has worked for a while, but has recently started failing. It seems that there isn't a constant handler for complex64 anymore? The following reproducer demonstrates the issue, but replacing out_dtype.type(1j) with 1j fixes the problem.

This works on jax 0.2.26 with jaxlib 0.1.75, but fails on jax 0.2.27 with jaxlib 0.1.76.

import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def phase_delay(lm, uvw, frequency):
    out_dtype = jnp.result_type(lm, uvw, frequency, np.complex64)

    one = lm.dtype.type(1.0)
    neg_two_pi_over_c = lm.dtype.type(-2*np.pi/3e8)

    l = lm[:, 0, None, None]  # noqa
    m = lm[:, 1, None, None]

    u = uvw[None, :, 0, None]
    v = uvw[None, :, 1, None]
    w = uvw[None, :, 2, None]

    n = jnp.sqrt(one - l**2 - m**2) - one

    real_phase = (neg_two_pi_over_c *
                  (l * u + m * v + n * w) *
                  frequency[None, None, :])
 
    # replacing out_dtype.type(1j) with 1j fixes this problem
    return jnp.exp(out_dtype.type(1j)*real_phase)

if __name__ == "__main__":
    uvw = np.random.random(size=(100, 3)).astype(np.float32)
    lm = np.random.random(size=(10, 2)).astype(np.float32)*0.001
    frequency = np.linspace(.856e9, .856e9*2, 64).astype(np.float32)

    complex_phase = phase_delay(lm, uvw, frequency)

$ python test_complex_constant_fail.py
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "test_complex_constant_fail.py", line 32, in <module>
    complex_phase = phase_delay(lm, uvw, frequency)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/_src/api.py", line 429, in cache_miss
    donated_invars=donated_invars, inline=inline)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/core.py", line 1671, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/core.py", line 1683, in call_bind
    outs = top_trace.process_call(primitive, fun, tracers, params)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/core.py", line 596, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/_src/dispatch.py", line 143, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/_src/dispatch.py", line 170, in _xla_callable_uncached
    *arg_specs).compile().unsafe_call
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/_src/dispatch.py", line 260, in lower_xla_callable
    donated_invars)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/interpreters/mlir.py", line 403, in lower_jaxpr_to_module
    input_output_aliases=input_output_aliases)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/interpreters/mlir.py", line 541, in lower_jaxpr_to_fun
    *args)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/interpreters/mlir.py", line 606, in jaxpr_subcomp
    in_nodes = map(read, eqn.invars)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/_src/util.py", line 44, in safe_map
    return list(map(f, *args))
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/interpreters/mlir.py", line 583, in read
    return ir_constants(v.val, canonicalize_types=True)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/interpreters/mlir.py", line 171, in ir_constants
    raise TypeError("No constant handler for type: {}".format(type(val)))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: No constant handler for type: <class 'numpy.complex64'>

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "test_complex_constant_fail.py", line 32, in <module>
    complex_phase = phase_delay(lm, uvw, frequency)
  File "/home/sperkins/venv/afr/lib/python3.7/site-packages/jax/interpreters/mlir.py", line 171, in ir_constants
    raise TypeError("No constant handler for type: {}".format(type(val)))
TypeError: No constant handler for type: <class 'numpy.complex64'>
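
The workaround mentioned above can be sketched in isolation as follows. This is a minimal illustration rather than the library's fix: `complex_phase` and its sample input are hypothetical names, and it relies on the observation in this report that a plain Python `1j` does not trigger the error, whereas a `numpy.complex64` scalar does.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def complex_phase(real_phase):
    # Promote the input dtype against complex64, as the reproducer does.
    out_dtype = jnp.result_type(real_phase, np.complex64)
    # Use a Python complex literal instead of out_dtype.type(1j), then
    # cast the result explicitly to the intended output dtype.
    return jnp.exp(1j * real_phase).astype(out_dtype)


# Hypothetical sample input standing in for the real phase term.
real_phase = np.linspace(0.0, 1.0, 8).astype(np.float32)
result = complex_phase(real_phase)
```

The explicit `astype` keeps the output dtype tied to `out_dtype` without embedding a NumPy scalar constant in the jitted function.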
@sjperkins sjperkins added the bug Something isn't working label Jan 31, 2022
@sjperkins sjperkins reopened this Jan 31, 2022
sjperkins added a commit to ratt-ru/codex-africanus that referenced this issue Jan 31, 2022
@jakevdp
Collaborator

jakevdp commented Jan 31, 2022

Thanks for the report. I can reproduce with jaxlib 0.1.76; the jax version doesn't appear to matter.

@jakevdp
Collaborator

jakevdp commented Jan 31, 2022

Shorter repro:

import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def f():
  return jnp.exp(np.complex64(1j))
f()

@jakevdp
Collaborator

jakevdp commented Jan 31, 2022

I think the issue is that complex dtypes were left out of the list here: https://github.com/google/jax/blob/0382a6a04eddd7506a4ef6bb0c93f0f660ee3df6/jax/interpreters/mlir.py#L243-L247
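
To illustrate the kind of omission described here, the following is a simplified sketch of type-keyed constant-handler dispatch. It is not JAX's actual code: `register_constant_handler`, `lower_constant`, and `_describe_constant` are hypothetical names, and only the error message is taken from the traceback in this issue.

```python
import numpy as np

# Hypothetical registry mapping scalar types to handler functions.
_constant_handlers = {}


def register_constant_handler(type_, handler):
    _constant_handlers[type_] = handler


def lower_constant(val):
    # Dispatch on the exact type of the value; unregistered types fail.
    handler = _constant_handlers.get(type(val))
    if handler is None:
        raise TypeError("No constant handler for type: {}".format(type(val)))
    return handler(val)


def _describe_constant(val):
    # Stand-in for real lowering: record the dtype name and the value.
    arr = np.asarray(val)
    return ("constant", arr.dtype.name, arr.tolist())


# Registering only real-valued scalar types reproduces the failure mode.
for scalar_type in [np.float32, np.float64, np.int32, np.int64]:
    register_constant_handler(scalar_type, _describe_constant)

try:
    lower_constant(np.complex64(1j))
    failed = False
except TypeError as err:
    failed = True
    message = str(err)

# Adding the complex scalar types to the list makes the lookup succeed.
for scalar_type in [np.complex64, np.complex128]:
    register_constant_handler(scalar_type, _describe_constant)

lowered = lower_constant(np.complex64(1j))
```

Because the lookup is keyed on the exact scalar type, leaving a type out of the registration loop is enough to produce the error, and a Python `complex` value would take a different key entirely.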

@hawkinsp
Collaborator

hawkinsp commented Feb 1, 2022

We're going to make a jax release very shortly incorporating this fix.

@sjperkins
Author

> We're going to make a jax release very shortly incorporating this fix.

Thanks for the quick turnaround!
