I’m not sure whether this issue is specific to random_flax_module or a broader problem, but I’ve primarily been using NumPyro for HMC BNNs, and the slowdown with the latest JAX release is quite dramatic.
Code:
import time

import numpy as np
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC
from numpyro.contrib.module import random_flax_module
import flax.linen as nn

# Set a random seed for reproducibility
rng_key = jax.random.PRNGKey(0)


# Generate some dummy data
def generate_data(n=100, noise_std=0.1):
    X = jnp.linspace(-1, 1, n)
    y = 3 * X + 2 + np.random.normal(0, noise_std, size=X.shape)
    return X[:, None], y


# Define a simple neural network
class SimpleNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(10)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x.squeeze()


# Define the model
def model(X, y):
    module = SimpleNN()
    nn = random_flax_module("nn", module, input_shape=(1, X.shape[-1]), prior=dist.Normal(0, 1))
    with numpyro.plate("data", X.shape[0]):
        mean = nn(X)
        numpyro.sample("obs", dist.Normal(mean, 0.1), obs=y)


# Generate data
X, y = generate_data()

# Initialize the NUTS sampler
nuts_kernel = NUTS(model)

# Run inference
num_warmup, num_samples = 500, 1000
start_time = time.time()
mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(rng_key, X, y)
end_time = time.time()

# Print runtime
print(f"Runtime: {end_time - start_time:.2f} seconds")

# Print summary statistics (print_summary() prints directly and returns None)
mcmc.print_summary()
I had similar issues with SVI inference when I switched from JAX 0.4.30 to 0.4.33. The workaround suggested in jax-ml/jax#23822 (i.e., setting the environment variable XLA_FLAGS=--xla_cpu_use_thunk_runtime=false) seemed to bring my runtimes back to what they were with JAX 0.4.30.
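For anyone who wants to try that workaround from inside a script or notebook rather than the shell, here is a minimal sketch. The flag itself is the one quoted above. The helper name disable_cpu_thunk_runtime is made up for illustration, and the sketch assumes the variable has to be set before JAX is first imported, so in Colab it would need to run in a fresh runtime.

```python
import os


def disable_cpu_thunk_runtime(env=os.environ):
    """Append the workaround flag to XLA_FLAGS, keeping any flags already set.

    Hypothetical helper for illustration; only the flag string comes from
    the linked JAX issue.
    """
    flag = "--xla_cpu_use_thunk_runtime=false"
    existing = env.get("XLA_FLAGS", "")
    if flag not in existing.split():
        env["XLA_FLAGS"] = f"{existing} {flag}".strip()
    return env["XLA_FLAGS"]


# Call this before the first `import jax` / `import numpyro` in the process.
disable_cpu_thunk_runtime()

# import jax
# import numpyro
```

Appending instead of overwriting keeps any other XLA flags you may already have set. Whether this restores the old runtimes for the NUTS example above is untested here, so it is worth re-running the timing script to check.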
Jax-0.4.31: Runtime: 27.06 seconds
https://colab.research.google.com/drive/1EsFY1St8Y2ZNBZ9UXTa9FDWrjPDdTU4U?usp=sharing
Jax-0.4.33: Runtime: 84.91 seconds
https://colab.research.google.com/drive/1g7GkuK4-GloO6cywvDUf5BVU9qO2jf1W?usp=sharing