Jax 0.2 does not support jax.numpy.reshape with non-constant values in omnistaging mode #43
I just pushed a fix, and with that, this model appears to work when ...
Looks like the problem could be related to a JAX + fork multiprocessing issue.
I tried passing a multiprocessing context via `mp_ctx`:

```python
In [5]: with model:
   ...:     trace = pm.sample(1000, chains=2, mp_ctx=ctx)
```

```
Auto-assigning NUTS sampler...
INFO:pymc3:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO:pymc3:Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 4 jobs)
INFO:pymc3:Multiprocess sampling (2 chains in 4 jobs)
NUTS: [sigma, beta]
INFO:pymc3:NUTS: [sigma, beta]
~/projects/code/python/Theano/theano/sandbox/jax.py:198: UserWarning: `jnp.copy` is not implemented yet. Using the object's `copy` method.
  warn("`jnp.copy` is not implemented yet. " "Using the object's `copy` method.")
~/projects/code/python/Theano/theano/sandbox/jax.py:202: UserWarning: Object has no `copy` method: Traced<ShapedArray(float64[2]):JaxprTrace(level=-1/1)>
  warn("Object has no `copy` method: {}".format(x))
~/projects/code/python/Theano/theano/sandbox/jax.py:202: UserWarning: Object has no `copy` method: Traced<ShapedArray(float64[]):JaxprTrace(level=-1/1)>
  warn("Object has no `copy` method: {}".format(x))
---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback:
"""
Traceback (most recent call last):
  File "~/projects/code/python/pymc3/pymc3/parallel_sampling.py", line 114, in _unpickle_step_method
    self._step_method = pickle.loads(self._step_method)
  File "~/projects/code/python/Theano/theano/compile/mode.py", line 305, in __setstate__
    linker = predefined_linkers[linker]
KeyError: 'jax'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "~/projects/code/python/pymc3/pymc3/parallel_sampling.py", line 135, in run
    self._unpickle_step_method()
  File "~/projects/code/python/pymc3/pymc3/parallel_sampling.py", line 116, in _unpickle_step_method
    raise ValueError(unpickle_error)
ValueError: The model could not be unpickled. This is required for sampling with more than one core and multiprocessing context spawn or forkserver.
"""

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
ValueError: The model could not be unpickled. This is required for sampling with more than one core and multiprocessing context spawn or forkserver.
```

The pickling error arises due to the absence of a `'jax'` entry in `predefined_linkers` within the worker process (hence the `KeyError: 'jax'`). I'll try adding the ...
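As a side note, here is a minimal sketch of what re-registering the linker for `spawn`/`forkserver` workers could look like. It is only an illustration built from the registration call used in the repro script later in this thread, not the actual fix that was pushed: since those start methods launch a fresh interpreter, the `'jax'` entry has to be added to `predefined_linkers` again before the step method is unpickled, for example at import time of a module the workers also import.

```python
# Sketch only: register the JAX linker so that unpickling a Mode(linker="jax")
# in a freshly spawned worker process does not hit KeyError: 'jax'.
import theano
import theano.sandbox.jax

if "jax" not in theano.compile.mode.predefined_linkers:
    theano.compile.mode.predefined_linkers["jax"] = theano.sandbox.jax.JaxLinker()
```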
Tried again with the updated branch and `cores=1, chains=1`, and now getting: ...
Looks like the same error. Here's the exact code I'm using on e4043ce0b and its output:

```python
import pymc3 as pm
import theano
import numpy as np

import theano.sandbox.jax

theano.compile.mode.predefined_linkers["jax"] = theano.sandbox.jax.JaxLinker()
jax_mode = theano.compile.Mode(linker="jax")

x = np.linspace(0, 1, 10)
y = x * 4. + 1.4 + np.random.randn(10)

with pm.Model() as model:
    beta = pm.Normal("beta", 0., 5., shape=2)
    sigma = pm.HalfNormal("sigma", 2.5)
    obs = pm.Normal("obs", beta[0] + beta[1] * x, sigma, observed=y)
    pm.sample(mode=jax_mode, chains=1, cores=1)
```

```
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
/home/bwillard/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Sequential sampling (1 chains in 1 job)
INFO:pymc3:Sequential sampling (1 chains in 1 job)
NUTS: [sigma, beta]
INFO:pymc3:NUTS: [sigma, beta]
Sampling 1 chain for 1_000 tune and 1_000 draw iterations (1_000 + 1_000 draws total) took 22 seconds.
INFO:pymc3:Sampling 1 chain for 1_000 tune and 1_000 draw iterations (1_000 + 1_000 draws total) took 22 seconds.
There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There was 1 divergence after tuning. Increase `target_accept` or reparameterize.
Only one chain was sampled, this makes it impossible to run some convergence checks
INFO:pymc3:Only one chain was sampled, this makes it impossible to run some convergence checks
Out[2]: <MultiTrace: 1 chains, 1000 iterations, 3 variables>
```

Perhaps it's a difference in the versions:

```
jax     0.1.75
jaxlib  0.1.52
```
Ah, I'll try those versions. Otherwise, I just pushed an update that puts the JAX ...

```python
import multiprocessing as mp

import theano
import numpy as np
import pymc3 as pm

ctx = mp.get_context('spawn')

jax_mode = theano.compile.Mode(linker="jax")

x = np.linspace(0, 1, 10)
y = x * 4. + 1.4 + np.random.randn(10)

with pm.Model() as model:
    beta = pm.Normal("beta", 0., 5., shape=2)
    sigma = pm.HalfNormal("sigma", 2.5)
    obs = pm.Normal("obs", beta[0] + beta[1] * x, sigma, observed=y)

with model:
    pm.sample(mode=jax_mode, mp_ctx=ctx)
```

```
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
~/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Multiprocess sampling (4 chains in 4 jobs)
INFO:pymc3:Multiprocess sampling (4 chains in 4 jobs)
NUTS: [sigma, beta]
INFO:pymc3:NUTS: [sigma, beta]
~/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
~/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
~/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
~/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: UserWarning: No GPU/TPU found, falling back to CPU.
  warnings.warn('No GPU/TPU found, falling back to CPU.')
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 67 seconds.
INFO:pymc3:Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 67 seconds.
There were 16 divergences after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There were 16 divergences after tuning. Increase `target_accept` or reparameterize.
There were 2 divergences after tuning. Increase `target_accept` or reparameterize.
ERROR:pymc3:There were 2 divergences after tuning. Increase `target_accept` or reparameterize.
The number of effective samples is smaller than 25% for some parameters.
INFO:pymc3:The number of effective samples is smaller than 25% for some parameters.
Out[3]: <MultiTrace: 4 chains, 1000 iterations, 3 variables>
```
That's awesome!
OK, I got the same error as you using ...
For anyone who's interested and more familiar with JAX, here's a MWE of the problem under jax 0.2.0 and jaxlib 0.1.55:

```python
import numpy as np

import jax  # needed for jax.jit below
import jax.numpy as jnp

x = np.zeros((2 * 3))
z = (2, 3)

expected_res = np.reshape(x, np.array(z, dtype=np.int))

def b(z):
    return jnp.array(z, dtype=np.int)

def a(x, z):
    return jnp.reshape(x, b(z))

jax_res = jax.jit(a)(x, z)
```

```
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.

The error arose in jax.numpy.reshape.

While tracing the function a at <ipython-input-26-182028b110f7>:10, this value became a tracer due to JAX operations on these lines:

  operation d:int64[1] = broadcast_in_dim[ broadcast_dimensions=(  )
                                           shape=(1,) ] b:int64[]
    from line <ipython-input-26-182028b110f7>:8 (b)

  operation e:int64[1] = broadcast_in_dim[ broadcast_dimensions=(  )
                                           shape=(1,) ] c:int64[]
    from line <ipython-input-26-182028b110f7>:8 (b)

You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.

See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.

Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>
```
jax issue?
If I understand correctly, JAX does not support jitting `reshape` when the shape is dynamic.
Oh, so they used to support it. Hmm, maybe marking the shape as static will help? I.e., something like the sketch below works.
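A minimal sketch of that idea, reworking the toy MWE above rather than the real Theano-generated function (`static_argnums` is the `jit` parameter the error message itself points at):

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp

x = np.zeros(2 * 3)
z = (2, 3)

# With `z` marked static, it is passed as a concrete Python tuple instead of
# being traced, so the reshape sees a constant shape.
@partial(jax.jit, static_argnums=(1,))
def reshape_static(x, z):
    return jnp.reshape(x, z)

print(reshape_static(x, z).shape)  # (2, 3)
```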
The actual graph is likely much more complex, and I don't know if there's an argument in the JITed function that's actually static and maps directly to the shape argument of the `reshape`. More importantly, why isn't it supported anymore?
I think it is related to a change called omnistaging: jax-ml/jax#3370
TFP also had to make quite a few changes due to omnistaging in JAX: https://github.com/tensorflow/probability/search?q=omnistaging&type=commits
Downgrading `jax`, I can now sample.
For reshape specifically: tensorflow/probability@782d0c6
Here's the sub-graph causing the problem:

```python
from theano.printing import debugprint as tt_dprint

dlogp_fn = model.logp_dlogp_function(mode=jax_mode)
dlogp_fgraph = dlogp_fn._theano_function.maker.fgraph

tt_dprint(dlogp_fgraph.outputs[1].owner.inputs[1])
```

```
Reshape{1} [id A] ''
 |Elemwise{Composite{(Switch(i0, (i1 * i2 * i2), i3) + i4 + (i5 * (((i6 * i7 * Composite{inv(Composite{(sqr(i0) * i0)}(i0))}(i2)) / i8) - (i9 * Composite{inv(Composite{(sqr(i0) * i0)}(i0))}(i2))) * i2))}}[(0, 7)] [id B] '(d__logp/dsigma_log__)'
 | |Elemwise{Composite{Cast{int8}(GE(i0, i1))}} [id C] ''
 | | |Elemwise{exp,no_inplace} [id D] 'sigma'
 | | | ...
 |MakeVector{dtype='int64'} [id BL] ''
   |Elemwise{Composite{(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(i0, i1), i2) - Switch(LT(Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(i3, i1), i2), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(i0, i1), i2)), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(i3, i1), i2), Composite{Switch(LT(i0, i1), i1, i0)}(Composite{Switch(GE(i0, i1), i1, i0)}(i0, i1), i2)))}}[(0, 1)] [id BM] ''
     |TensorConstant{3} [id BN]
     |Shape_i{0} [id BO] ''
     | |__args_joined [id G]
     |TensorConstant{0} [id K]
     |TensorConstant{2} [id BP]
```

The graph tells us that the shape argument of the `Reshape` is built by a `MakeVector` whose entries are computed from `Shape_i` (the shape of `__args_joined`) through `Switch`-based clamping logic; in other words, it is not a constant.

This means that a more representative MWE would probably be as follows:

```python
import numpy as np

import jax
import jax.numpy as jnp

x = np.zeros((2 * 3))

def d(y, z):
    return jnp.shape(y)[z]

def c(y, z):
    return jnp.where(d(y, 0) / z > 0, d(y, 0) / z, 0)

def b(y):
    return jnp.array([c(y, 2), c(y, 3)], dtype=np.int)

def a(y):
    return jnp.reshape(y, b(y))

# Fails with the same ConcretizationTypeError under omnistaging.
jax_res = jax.jit(a)(x)
```

Unfortunately, I don't think we can use something like `static_argnums` here.
I just pushed a commit that disables omnistaging by default. That should allow the examples above to work again.

I also found this example, which is essentially our problem, and it says that the solution is to use NumPy to compute the shape. @junpenglao, is there a straightforward way to force a NumPy computation of those shape values?
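For reference, a sketch of what disabling omnistaging looks like from user code, assuming the jax 0.2.x series where this opt-out still exists:

```python
import jax

# Equivalent to setting the JAX_OMNISTAGING=0 environment variable before jax
# is imported; this escape hatch is only available while jax still ships the
# pre-omnistaging code path.
jax.config.disable_omnistaging()
```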
It works for me with that workaround now 👍.
Let's keep this open, as we should aim to fix the omnistaging issue.
@junpenglao, does the solution to this particular issue involve ... ?
I think there are a few levels of fix we could think about here.

In the short term, I need to understand better how the reshape is done in Theano. IIUC, when the shape can be computed with plain NumPy, outside of the traced JAX operations, something like this works:

```python
import numpy as np

import jax
import jax.numpy as jnp

x = np.zeros((2 * 3))  # as in the MWE above

def d(y, z):
    return np.shape(y)[z]

def c(y, z):
    # Integer division keeps the shape entries integers.
    return np.where(d(y, 0) // z > 0, d(y, 0) // z, 0)

def b(y):
    return [c(y, 2), c(y, 3)]

def a(y):
    return jnp.reshape(y, b(y))

jax_res = jax.jit(a)(x)
```
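A quick check of that sketch against plain NumPy (using the `x = np.zeros((2 * 3))` defined above, so the computed shape is `(6 // 2, 6 // 3) == (3, 2)`):

```python
# jax_res comes from the snippet above; it should equal a (3, 2) reshape of x.
np.testing.assert_array_equal(np.asarray(jax_res), np.reshape(x, (3, 2)))
```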
We can't do this for all Theano graphs; that would remove a wide array of valuable Theano capabilities!
I see. Could we add a static-reshape mode for pymc3 instead? JAX does not really support dynamic shapes anyway, so if we want the JAX backend to do dynamic graph stuff, it would be pretty difficult.
It seems like we should dig into the lower-level aspects of `jax` and see if we can take a more direct approach from our end, one that doesn't go through these omnistaging changes, for instance.

This seems like the correct approach if only because it might also provide solutions to similar symbolic limitations we've encountered (e.g. #68).
We could consider linking to XLA directly; not sure how much work that would be, though.
Yeah, that's what I was thinking.
From a jax post explaining omnistaging (https://github.com/google/jax/blob/master/design_notes/omnistaging.md#solution) it looks like using `numpy.reshape` instead of `jax.numpy.reshape` is the correct way to solve this reshape problem.
But that would incur Python call overhead, no? I think that's just a poor workaround.

The last I recall from reading and experimenting with all that is that they're actually saying such graphs are no longer possible, so do everything involving `np.reshape` before using JAX.

We can handle that by (re)allowing a mix of Python and JAX thunks, of course, but that's much less ideal than a single JAX-compiled/JITed function, especially when it comes to interactions with other JAX code (e.g. JAX-based sampler functions).
Can numba do that?
Can it compile a function that uses `np.reshape` with a non-constant shape? Here's an example:

```python
import numpy as np
import numba

@numba.njit
def testfn(x, shape):
    return np.reshape(x, shape)
```

```
>>> testfn(np.ones(10), (5, 2))
array([[1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.],
       [1., 1.]])
```

As I understand it, JAX isn't meant to be an all-purpose JITer; it's constrained by its connections to XLA, a specific domain of work/relevance, etc.