add randomness and return_state_dict instructions to README
samuela committed Jan 4, 2025
1 parent caea0ac commit 6faa450
Showing 1 changed file with 38 additions and 10 deletions.
README.md
@@ -29,7 +29,7 @@ def f(x):
x.mul_(2)
return x

-f(torch.Tensor([3])) # => torch.Tensor([8])
f(torch.tensor([3])) # => torch.Tensor([8])

jax_f = t2j(f)
jax_f(jnp.array([3])) # => jnp.array([8])
@@ -69,17 +69,47 @@ torch2jax is an implementation of the PyTorch standard library written in JAX.

Adding new PyTorch operations is straightforward. Check the source for functions decorated with `@implements` to get started.

-### My PyTorch model includes dropout (or some other random operation), and does not work in training mode. Why?
### My PyTorch model includes dropout or some other random operation. How does this work with torch2jax?

-JAX mandates [deterministic randomness](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html), while PyTorch does not. This leads to some API friction. torch2jax does not currently offer a means to bridge this gap. I have an idea for how to accomplish it. If this is important to you, please open an issue.
Pass a `jax.random.PRNGKey` to the converted function:

-In the meantime, make sure to call `.eval()` on your `torch.nn.Module` before conversion.
```python
t2j(lambda: torch.randn(3))(rng=jax.random.PRNGKey(123))
# => [-0.56996626, -0.6440589 , 0.28660855]

t2j(lambda: torch.randn(3))(rng=jax.random.PRNGKey(456))
# => [-1.3227656, -1.4896724, -2.5057693]
```

After conversion, random state will be handled entirely in JAX. `torch.manual_seed` and its ilk will have no effect on the converted function.
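Reproducibility therefore comes from the key you pass in: reusing a key reproduces a draw, and splitting a key yields independent draws. A minimal sketch (the `from torch2jax import t2j` import path is assumed):

```python
import jax
import torch
from torch2jax import t2j  # assumed import path

sample = t2j(lambda: torch.randn(3))

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)  # independent draws come from distinct keys

a = sample(rng=k1)
b = sample(rng=k2)  # different from `a`
c = sample(rng=k1)  # reusing a key reproduces `a`
```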

If you only care about running a model and not training it, you can call `.eval()` on it to avoid the randomness issue altogether, at least for most common random operations like dropout:

```python
rn18 = torchvision.models.resnet18().eval()
t2j(rn18)(t2j(torch.randn(1, 3, 224, 224))) # Look ma, no `rng` kwarg!
```

-### My PyTorch model includes batch norm (or some other `torch.nn.Module` utilizing buffers), and does not work in training mode. What can I do?
> [!NOTE]
> Non-deterministic behavior is, well, non-deterministic. You will not see the same results with the same random seed when switching between PyTorch and JAX. However, the sampling process _will_ be equivalent.
-Similar to the randomness story, PyTorch and JAX have different approaches to maintaining state. Operations like batch norm require maintaining running statistics. In PyTorch, this is accomplished via [buffers](https://stackoverflow.com/questions/57540745/what-is-the-difference-between-register-parameter-and-register-buffer-in-pytorch/57546078#57546078).
### My PyTorch model includes batch norm or some other `torch.nn.Module` that mutates buffers. How does this work with torch2jax?

-torch2jax supports running batch norm models in `eval()`-mode. Just don't forget that you should avoid taking gradients w.r.t. buffers. For example,
Some PyTorch modules like `torch.nn.BatchNorm1d` mutate internal state in the form of [buffers](https://stackoverflow.com/questions/57540745/what-is-the-difference-between-register-parameter-and-register-buffer-in-pytorch/57546078#57546078).

torch2jax supports this with the optional `return_state_dict` argument:

```python
rn18 = torchvision.models.resnet18()
batch = torch.randn(1, 3, 224, 224)

before_state_dict = {k: t2j(v) for k, v in rn18.state_dict().items()}
out, after_state_dict = t2j(rn18)(t2j(batch), state_dict=before_state_dict, return_state_dict=True)
```
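Continuing the snippet above, the returned state dict can be fed back in on subsequent calls so that buffer updates (e.g. batch norm running statistics) accumulate across batches. A sketch, where `batches` is a placeholder for your own data iterator:

```python
# Hypothetical loop: thread the updated buffers through successive calls.
jax_rn18 = t2j(rn18)
state_dict = {k: t2j(v) for k, v in rn18.state_dict().items()}
for torch_batch in batches:  # `batches` stands in for your own data source
    out, state_dict = jax_rn18(t2j(torch_batch), state_dict=state_dict, return_state_dict=True)
```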

As with randomness, if you only care about running a model and not training it, you can call `.eval()` on it to avoid buffer issues altogether in most cases.

Also, don't forget to avoid taking gradients w.r.t. buffers. For example,

```python
rn18 = torchvision.models.resnet18().eval()
@@ -95,13 +125,11 @@ jax_rn18 = t2j(rn18)
grad(lambda params, x: loss(jax_rn18(x, state_dict={**params, **buffers})))(parameters, t2j(batch))
```

-I have an idea for how to implement buffers, including in training mode. If this is important to you, please open an issue.

### I'm seeing slightly different numerical results between PyTorch and JAX. Is it a bug?

Floating point arithmetic is hard. There are a number of sources of divergence preventing bit-for-bit equivalence:

-1. torch2jax guarantees equivalence with PyTorch standard library functions in the mathematical sense, but not necessarily in their operational execution. This can lead to slight differences in results.
1. torch2jax guarantees equivalence with PyTorch standard library functions in the mathematical sense, but not necessarily in their operational execution. This can lead to slight differences in results. For example, the multi-head attention implementations calculate the same mathematical function, but may vary in execution details such as the order of operations, the use of fused kernels, and so forth.
2. The JAX/XLA and PyTorch compilers apply different optimizations and should be expected to rewrite computation graphs in exciting and unpredictable ways, potentially invoking different CUDA kernels.
3. CUDA kernels can be non-deterministic, for example as a result of floating point addition being non-associative.
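In practice, this means comparisons between PyTorch and torch2jax outputs should use a tolerance rather than exact equality. A sketch (the `t2j` import path and the tolerances are illustrative):

```python
import numpy as np
import torch
import torchvision
from torch2jax import t2j  # assumed import path

rn18 = torchvision.models.resnet18().eval()
x = torch.randn(1, 3, 224, 224)

torch_out = rn18(x).detach().numpy()
jax_out = np.asarray(t2j(rn18)(t2j(x)))

# Expect agreement up to floating point noise, not bit-for-bit equality.
np.testing.assert_allclose(torch_out, jax_out, rtol=1e-4, atol=1e-5)
```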

