diff --git a/examples/research_projects/jax-projects/README.md b/examples/research_projects/jax-projects/README.md index be8c6dc326bf..267e7c7078dc 100644 --- a/examples/research_projects/jax-projects/README.md +++ b/examples/research_projects/jax-projects/README.md @@ -621,7 +621,7 @@ state = model_flax.init(rng, dummy_input_ids) and then we can do the forward pass. ```python -sequences = model_flax.apply(input_ids, state) +sequences = model_flax.apply(state, input_ids) ``` Visually, the forward pass would now be represented as passing all tensors required for the computation to the model's object: