NUTS performance concerns on GPU #597

Open · PaulScemama opened this issue Dec 1, 2023 · 4 comments

@PaulScemama (Contributor)
Describe the issue as clearly as possible:

On a trivial example (the one from quickstart.md), I appear to be hitting a strange performance bug with the NUTS sampler on a GPU.

When I run the script (copied below) with a GPU for 200 steps I get

Jax sees these devices: [gpu(id=0)]
Starting to run nuts for 200 steps
Nuts took 0.050431712468465166 minutes

When I run the script with a GPU for 300 steps I get

Jax sees these devices: [gpu(id=0)]
Starting to run nuts for 300 steps
Nuts took 0.8048396507898966 minutes

When I run the script with a GPU for 500 steps I get

Jax sees these devices: [gpu(id=0)]
Starting to run nuts for 500 steps
Nuts took 1.2937044938405355 minutes

When I run the script on CPU with 1000 steps I get

Jax sees these devices: [CpuDevice(id=0)]
Starting to run nuts for 1000 steps
Nuts took 0.06121724049250285 minutes

Steps/code to reproduce the bug:

import numpy as np

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

import blackjax


from datetime import date

rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))


# Generate synthetic observations from a univariate normal.
loc, scale = 10, 20
observed = np.random.normal(loc, scale, size=1_000)


def logdensity_fn(loc, log_scale, observed=observed):
    """Univariate Normal"""
    scale = jnp.exp(log_scale)
    logpdf = stats.norm.logpdf(observed, loc, scale)
    return jnp.sum(logpdf)


logdensity = lambda x: logdensity_fn(**x)


def inference_loop(rng_key, kernel, initial_state, num_samples):
    """Run `kernel` for `num_samples` steps with jax.lax.scan."""

    @jax.jit
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states


# NUTS with a fixed step size and a diagonal inverse mass matrix.
inv_mass_matrix = np.array([0.5, 0.01])
step_size = 1e-3

nuts = blackjax.nuts(logdensity, step_size, inv_mass_matrix)

initial_position = {"loc": 1.0, "log_scale": 1.0}
initial_state = nuts.init(initial_position)

rng_key, sample_key = jax.random.split(rng_key)

# Time the NUTS inference loop.
import time

start = time.time()
num_steps = 500
print(f"Jax sees these devices: {jax.devices()}")
print(f"Starting to run nuts for {num_steps} steps")
states = inference_loop(sample_key, nuts.step, initial_state, num_steps)
end = time.time()
print(f"Nuts took {(end-start)/60} minutes")

Expected result:

A shorter amount of time to run. I am not very familiar with the CPU/GPU trade-offs for MCMC samplers like NUTS. Maybe CPU is simply much faster here? If so, I think a warning would be nice: consider the scenario where a user runs variational inference on a GPU, then switches to NUTS and finds it takes forever.

Error message:

No response

Blackjax/JAX/jaxlib/Python version information:

BlackJAX 0.1.dev454+g164a4dd
Python 3.9.17 (main, Nov 28 2023, 23:51:11) 
[GCC 7.5.0]
Jax 0.4.16
Jaxlib 0.4.16

Context for the issue:

No response

@junpenglao (Member)

This is not my experience. What is your environment?

Also, an important note on benchmarking JAX: https://jax.readthedocs.io/en/latest/async_dispatch.html
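
For reference, because JAX dispatches work asynchronously, `time.time()` can stop the clock before the device has actually finished. A minimal sketch of a corrected benchmark, reusing the objects from the script above (the warm-up run is added so the timed run can reuse compiled code where possible):

import time

import jax

# Warm-up run with the same number of steps, so the timed run below
# can reuse compiled computations instead of paying compilation cost.
jax.block_until_ready(
    inference_loop(sample_key, nuts.step, initial_state, num_steps)
)

start = time.time()
states = inference_loop(sample_key, nuts.step, initial_state, num_steps)
# Block until the device has finished; without this, time.time() may
# measure only the asynchronous dispatch rather than the computation.
jax.block_until_ready(states)
end = time.time()
print(f"Nuts took {(end - start) / 60} minutes")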

@PaulScemama (Contributor, Author)

@junpenglao I will check back on this later this weekend -- it's possible that it is an environment problem. I'll get back to you then.

@DanWaxman

This is replicable on Colab, so I don't think it's an environment issue.

Output for CPU:

Jax sees these devices: [CpuDevice(id=0)]
Starting to run nuts for 500 steps
NUTS Call took 0.12702747980753581 minutes

Output for GPU:

Jax sees these devices: [cuda(id=0)]
Starting to run nuts for 500 steps
NUTS Call took 0.7922836542129517 minutes

I think this is more or less expected behavior, though: the problem is rather small and doesn't involve operations GPUs are particularly good at. There was a similar discussion for NumPyro here, with the takeaway being that JAX is particularly efficient on CPU, and GPU acceleration only makes sense for certain problems.
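
One regime where the GPU can win is many chains in parallel, since the per-step math is then batched across chains. A hypothetical sketch reusing the `nuts`, `inference_loop`, `sample_key`, and `num_steps` objects from the script above (`num_chains` and the batched initial positions are made up for illustration; note that under `vmap`, each step every chain waits for the longest NUTS trajectory in the batch):

import jax
import jax.numpy as jnp

num_chains = 128  # hypothetical chain count

# One initial position per chain: every leaf gets a leading
# (num_chains,) axis.
init_positions = {
    "loc": jnp.ones(num_chains),
    "log_scale": jnp.ones(num_chains),
}
initial_states = jax.vmap(nuts.init)(init_positions)

# Run every chain with the same scan-based loop, batched over chains.
chain_keys = jax.random.split(sample_key, num_chains)
states = jax.vmap(
    lambda key, state: inference_loop(key, nuts.step, state, num_steps)
)(chain_keys, initial_states)
jax.block_until_ready(states)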

@gil2rok (Contributor) commented Jun 21, 2024

Note that NUTS is control-flow heavy, which makes it hard to run fast on a GPU.

See the ChEES algorithm, implemented in BlackJAX, for a NUTS-like sampler that avoids this problem.
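
ChEES-HMC adapts a single trajectory length across many chains, so the resulting kernel has no data-dependent control flow and batches well on a GPU. A rough sketch of how it might be wired up with BlackJAX's chees_adaptation, reusing `logdensity` from the script above (the chain count, learning rate, and warmup length are made-up values, and the exact signature may differ between BlackJAX versions; check the current docs):

import jax
import jax.numpy as jnp
import optax

import blackjax

num_chains = 64          # hypothetical
num_warmup_steps = 1000  # hypothetical

# One position per chain: each leaf has a leading (num_chains,) axis.
initial_positions = {
    "loc": jnp.ones(num_chains),
    "log_scale": jnp.ones(num_chains),
}

warmup_key, sample_key = jax.random.split(jax.random.key(0))

# Adapt the step size and trajectory length with ChEES.
warmup = blackjax.chees_adaptation(logdensity, num_chains)
optim = optax.adam(learning_rate=0.01)
(last_states, parameters), _ = warmup.run(
    warmup_key, initial_positions, 1e-3, optim, num_warmup_steps
)

# The adapted kernel uses a fixed trajectory length (no data-dependent
# control flow), so it vmaps cleanly across chains on a GPU.
kernel = jax.jit(jax.vmap(blackjax.dynamic_hmc(logdensity, **parameters).step))
chain_keys = jax.random.split(sample_key, num_chains)
new_states, infos = kernel(chain_keys, last_states)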
