From 230a536d4a47741ede8f41ff662bceb92e67e401 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 7 Sep 2023 14:41:37 -0700 Subject: [PATCH] Migrate flax from using old-style PRNG keys to new-style typed PRNG keys Functionally, this involves changing uses of jax.random.PRNGKey to jax.random.key. For details on this change and the motivation behind it, see the draft JEP at https://github.com/google/jax/pull/17297, and please feel free to offer comments and feedback! PiperOrigin-RevId: 563549594 --- CHANGELOG.md | 6 +- README.md | 6 +- docs/developer_notes/lift.md | 8 +- docs/developer_notes/module_lifecycle.rst | 12 +- docs/flip/1009-optimizer-api.md | 2 +- docs/flip/2396-rnn.md | 4 +- docs/getting_started.ipynb | 4 +- docs/getting_started.md | 4 +- docs/guides/batch_norm.rst | 4 +- docs/guides/convert_pytorch_to_flax.rst | 10 +- docs/guides/dropout.rst | 6 +- docs/guides/ensembling.rst | 4 +- docs/guides/extracting_intermediates.rst | 18 +- docs/guides/flax_basics.ipynb | 14 +- docs/guides/flax_basics.md | 14 +- docs/guides/flax_on_pjit.ipynb | 2 +- docs/guides/flax_on_pjit.md | 2 +- docs/guides/haiku_migration_guide.rst | 30 +-- docs/guides/jax_for_the_impatient.ipynb | 6 +- docs/guides/jax_for_the_impatient.md | 6 +- docs/guides/lr_schedule.rst | 2 +- docs/guides/model_surgery.ipynb | 2 +- docs/guides/model_surgery.md | 2 +- docs/guides/optax_update_guide.rst | 2 +- docs/guides/regular_dict_upgrade_guide.rst | 6 +- docs/guides/rnncell_upgrade_guide.rst | 18 +- docs/guides/state_params.rst | 6 +- docs/guides/transfer_learning.ipynb | 2 +- docs/guides/transfer_learning.md | 2 +- docs/guides/use_checkpointing.ipynb | 2 +- docs/guides/use_checkpointing.md | 2 +- docs/index.rst | 4 +- docs/notebooks/flax_sharp_bits.ipynb | 4 +- docs/notebooks/flax_sharp_bits.md | 4 +- docs/notebooks/linen_intro.ipynb | 24 +-- docs/notebooks/linen_intro.md | 24 +-- docs/notebooks/optax_update_guide.ipynb | 2 +- docs/notebooks/optax_update_guide.md | 2 +- docs/notebooks/state_params.ipynb | 12 +- docs/notebooks/state_params.md | 12 +- examples/imagenet/imagenet.ipynb | 2 +- examples/imagenet/models_test.py | 4 +- examples/imagenet/train.py | 4 +- examples/imagenet/train_test.py | 8 +- .../linen_design_test/attention_simple.py | 2 +- examples/linen_design_test/autoencoder.py | 2 +- .../linen_design_test/linear_regression.py | 2 +- examples/linen_design_test/mlp_explicit.py | 2 +- examples/linen_design_test/mlp_inline.py | 2 +- examples/linen_design_test/mlp_lazy.py | 2 +- .../linen_design_test/tied_autoencoder.py | 2 +- examples/linen_design_test/weight_std.py | 2 +- examples/lm1b/temperature_sampler_test.py | 2 +- examples/lm1b/train.py | 4 +- examples/mnist/train.py | 2 +- examples/mnist/train_test.py | 2 +- examples/nlp_seq/train.py | 2 +- examples/ogbg_molpcba/models_test.py | 4 +- examples/ogbg_molpcba/train.py | 2 +- examples/ogbg_molpcba/train_test.py | 2 +- examples/ppo/ppo_lib.py | 2 +- examples/ppo/ppo_lib_test.py | 4 +- examples/seq2seq/seq2seq.ipynb | 2 +- examples/seq2seq/train.py | 2 +- examples/seq2seq/train_test.py | 6 +- examples/sst2/models.py | 2 +- examples/sst2/models_test.py | 8 +- examples/sst2/train.py | 2 +- examples/sst2/train_test.py | 2 +- examples/vae/train.py | 2 +- examples/wmt/train.py | 4 +- flax/core/flax_functional_engine.ipynb | 10 +- flax/core/meta.py | 4 +- flax/core/scope.py | 2 +- flax/cursor.py | 4 +- flax/errors.py | 30 +-- flax/linen/initializers.py | 4 +- flax/linen/module.py | 22 +-- flax/linen/normalization.py | 4 +- flax/linen/recurrent.py | 12 +- 
flax/linen/summary.py | 2 +- flax/linen/transforms.py | 18 +- pyproject.toml | 2 +- tests/core/core_lift_test.py | 20 +- tests/core/core_meta_test.py | 10 +- tests/core/core_scope_test.py | 20 +- tests/core/design/core_attention_test.py | 2 +- tests/core/design/core_auto_encoder_test.py | 6 +- tests/core/design/core_big_resnets_test.py | 4 +- tests/core/design/core_custom_vjp_test.py | 4 +- tests/core/design/core_dense_test.py | 8 +- tests/core/design/core_flow_test.py | 2 +- tests/core/design/core_resnet_test.py | 4 +- tests/core/design/core_scan_test.py | 8 +- .../core/design/core_tied_autoencoder_test.py | 4 +- tests/core/design/core_vmap_test.py | 8 +- tests/core/design/core_weight_std_test.py | 4 +- tests/cursor_test.py | 4 +- tests/linen/initializers_test.py | 4 +- tests/linen/linen_activation_test.py | 2 +- tests/linen/linen_attention_test.py | 14 +- tests/linen/linen_combinators_test.py | 12 +- tests/linen/linen_linear_test.py | 90 ++++----- tests/linen/linen_meta_test.py | 17 +- tests/linen/linen_module_test.py | 186 +++++++++--------- tests/linen/linen_recurrent_test.py | 70 +++---- tests/linen/linen_test.py | 36 ++-- tests/linen/linen_transforms_test.py | 124 ++++++------ tests/linen/partitioning_test.py | 22 +-- tests/linen/summary_test.py | 26 +-- tests/linen/toplevel_test.py | 10 +- tests/serialization_test.py | 14 +- tests/traceback_util_test.py | 6 +- 113 files changed, 612 insertions(+), 629 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 162f3a9dd8..eac3d98ef7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,11 @@ vNext - - - -- +- Use new typed PRNG keys throughout flax: this essentially involved changing + uses of `jax.random.PRNGKey` to `jax.random.key`. + (See [JEP 9263](https://github.com/google/jax/pull/17297) for details). + If you notice dispatch performance regressions after this change, be sure + you update `jax` to version 0.4.16 or newer. 
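For readers skimming this patch, here is a minimal sketch of the API difference being migrated. The example is illustrative only and not part of the diff itself; it assumes `jax` 0.4.16 or newer (as the changelog entry above notes) and uses a small `nn.Dense` module purely for demonstration:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=3)
x = jnp.ones((1, 4))

old_key = jax.random.PRNGKey(0)  # old style: uint32 array of shape (2,)
new_key = jax.random.key(0)      # new style: scalar array with a key dtype

# Both key styles are accepted by jax.random functions and by Flax's init/apply;
# this patch simply switches call sites from PRNGKey(...) to key(...).
params_old = model.init(old_key, x)
params_new = model.init(new_key, x)
sub1, sub2 = jax.random.split(new_key)  # splitting works the same way
```

The new-style keys carry their PRNG implementation in the dtype, which is what the linked JEP means by "typed" keys.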
- - - diff --git a/README.md b/README.md index 018a749284..b7ffe0f19c 100644 --- a/README.md +++ b/README.md @@ -119,7 +119,7 @@ class MLP(nn.Module): model = MLP([12, 8, 4]) batch = jnp.ones((32, 10)) -variables = model.init(jax.random.PRNGKey(0), batch) +variables = model.init(jax.random.key(0), batch) output = model.apply(variables, batch) ``` @@ -142,7 +142,7 @@ class CNN(nn.Module): model = CNN() batch = jnp.ones((32, 64, 64, 10)) # (N, H, W, C) format -variables = model.init(jax.random.PRNGKey(0), batch) +variables = model.init(jax.random.key(0), batch) output = model.apply(variables, batch) ``` @@ -174,7 +174,7 @@ model = AutoEncoder(encoder_widths=[20, 10, 5], decoder_widths=[5, 10, 20], input_shape=(12,)) batch = jnp.ones((16, 12)) -variables = model.init(jax.random.PRNGKey(0), batch) +variables = model.init(jax.random.key(0), batch) encoded = model.apply(variables, batch, method=model.encode) decoded = model.apply(variables, encoded, method=model.decode) ``` diff --git a/docs/developer_notes/lift.md b/docs/developer_notes/lift.md index cc32251f0a..b623a820c3 100644 --- a/docs/developer_notes/lift.md +++ b/docs/developer_notes/lift.md @@ -85,7 +85,7 @@ class ManualVmapMLP(nn.Module): return apply_fn({'params': mlp_params}, xs) xs = jnp.ones((3, 4)) -variables = ManualVmapMLP().init(random.PRNGKey(0), xs) +variables = ManualVmapMLP().init(random.key(0), xs) print(jax.tree_util.tree_map(jnp.shape, variables['params'])) """==> { @@ -270,7 +270,7 @@ def lift_transpose(fn, target='params', variables=True, rngs=True): rng_filters=(rngs,)) x = jnp.ones((3, 2)) -y, params = init(lift_transpose(core_nn.dense))(random.PRNGKey(0), x, 4) +y, params = init(lift_transpose(core_nn.dense))(random.key(0), x, 4) ``` NOTE that most users should not need to interact with `pack` directly. @@ -310,7 +310,7 @@ class LinenVmapMLP(nn.Module): VmapMLP = nn.vmap(MLP, variable_axes={'params': 0}, split_rngs={'params': True}, in_axes=0) return VmapMLP(name='mlp')(xs) -variables = LinenVmapMLP().init(random.PRNGKey(0), xs) +variables = LinenVmapMLP().init(random.key(0), xs) print(jax.tree_util.tree_map(jnp.shape, variables['params'])) """==> { @@ -346,7 +346,7 @@ class LinenStatefulVmapMLP(nn.Module): def __call__(self, xs, *, train): VmapMLP = nn.vmap(StatefulMLP, variable_axes={'params': 0, 'batch_stats': 0}, split_rngs={'params': True}, in_axes=0) return VmapMLP(name='mlp')(xs, train=train) -variables = LinenStatefulVmapMLP().init(random.PRNGKey(0), xs) +variables = LinenStatefulVmapMLP().init(random.key(0), xs) ``` All we had to add to `nn.vmap` is `'batch_stats': 0`, indicating that the batch stats are vectorized rather than shared along the first axis. diff --git a/docs/developer_notes/module_lifecycle.rst b/docs/developer_notes/module_lifecycle.rst index 8a28f8a38e..92077423f2 100644 --- a/docs/developer_notes/module_lifecycle.rst +++ b/docs/developer_notes/module_lifecycle.rst @@ -59,7 +59,7 @@ Now we want to construct and use the ``MLP`` Module: mlp = MLP(hidden_size=5, out_size=3) x = jax.numpy.ones((1, 2)) - variables = mlp.init(random.PRNGKey(0), x) + variables = mlp.init(random.key(0), x) y = mlp.apply(variables, x) @@ -70,8 +70,8 @@ Let's take a closer look at initialization. Surprisingly, there actually is no s .. 
testcode:: - # equivalent to: variables = mlp.init(random.PRNGKey(0), x) - _, variables = mlp.apply({}, x, rngs={"params": random.PRNGKey(0)}, mutable=True) + # equivalent to: variables = mlp.init(random.key(0), x) + _, variables = mlp.apply({}, x, rngs={"params": random.key(0)}, mutable=True) Thus, ``init`` is nothing more than a wrapper around ``apply`` where: @@ -155,7 +155,7 @@ Another benefit of defining submodules and/or variables inline is that you can a mdl = CompactScaledMLP(hidden_size=4, out_size=5) x = jax.numpy.ones((3, 2)) - vars = mdl.init(random.PRNGKey(0), x) + vars = mdl.init(random.key(0), x) assert vars["params"]["scale"].shape == (2,) Many of the standard Linen Modules like ``nn.Dense`` use shape inference already to avoid the need to specify input shapes (like the number of input features to a Dense layer). @@ -207,7 +207,7 @@ The latter is done as follows: return mdl(z, "decode") mdl = CorrectModule() - vars = nn.init(init_fn, mdl)(random.PRNGKey(0)) + vars = nn.init(init_fn, mdl)(random.key(0)) assert vars["params"]["Dense_0"]["kernel"].shape == (2, 8) assert vars["params"]["Dense_1"]["kernel"].shape == (8, 4) @@ -348,7 +348,7 @@ Function closure is the most common way to accidentally hide a JAX array or Line x = jax.numpy.ones((3, 2)) mdl = Foo() - vars = mdl.init(random.PRNGKey(0), x) + vars = mdl.init(random.key(0), x) assert vars['params']['Dense_0']['kernel'].shape == (3, 2, 2) diff --git a/docs/flip/1009-optimizer-api.md b/docs/flip/1009-optimizer-api.md index fd9d681fd6..f658db5ccf 100644 --- a/docs/flip/1009-optimizer-api.md +++ b/docs/flip/1009-optimizer-api.md @@ -496,7 +496,7 @@ def get_learning_rate(step): model = Model() -rng = jax.random.PRNGKey(0) +rng = jax.random.key(0) ds = tfds.load('mnist')['train'].take(160).map(pp).batch(16) batch = next(iter(ds)) variables = model.init(rng, jnp.array(batch['image'][:1])) diff --git a/docs/flip/2396-rnn.md b/docs/flip/2396-rnn.md index d957c8d03a..94a7e78a16 100644 --- a/docs/flip/2396-rnn.md +++ b/docs/flip/2396-rnn.md @@ -18,7 +18,7 @@ def __call__(self, x): nn.LSTMCell, variable_broadcast="params", split_rngs={"params": False} ) carry = LSTM.initialize_carry( - jax.random.PRNGKey(0), batch_dims=x.shape[:1], size=self.hidden_size + jax.random.key(0), batch_dims=x.shape[:1], size=self.hidden_size ) carry, x = LSTM()(carry, x) return x @@ -91,7 +91,7 @@ Where: * `initial_carry`: the initial carry, if not provided it will be initialized using the cell's :meth:`RNNCellBase.initialize_carry` method. * `init_key`: a PRNG key used to initialize the carry, if not provided - ``jax.random.PRNGKey(0)`` will be used. Most cells will ignore this + ``jax.random.key(0)`` will be used. Most cells will ignore this argument. 
* `seq_lengths`: an optional integer array of shape ``(*batch)`` indicating the length of each sequence, elements whose index in the time dimension diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index 0c4e91c2af..a589627569 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -223,7 +223,7 @@ "import jax.numpy as jnp # JAX NumPy\n", "\n", "cnn = CNN()\n", - "print(cnn.tabulate(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1))))" + "print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1))))" ] }, { @@ -521,7 +521,7 @@ }, "outputs": [], "source": [ - "init_rng = jax.random.PRNGKey(0)" + "init_rng = jax.random.key(0)" ] }, { diff --git a/docs/getting_started.md b/docs/getting_started.md index fec311bced..19a8701e9b 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -131,7 +131,7 @@ import jax import jax.numpy as jnp # JAX NumPy cnn = CNN() -print(cnn.tabulate(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)))) +print(cnn.tabulate(jax.random.key(0), jnp.ones((1, 28, 28, 1)))) ``` +++ {"id": "4b5ac16e"} @@ -332,7 +332,7 @@ executionInfo: timestamp: 1673483485436 id: e4f6f4d3 --- -init_rng = jax.random.PRNGKey(0) +init_rng = jax.random.key(0) ``` +++ {"id": "80fbb60b"} diff --git a/docs/guides/batch_norm.rst b/docs/guides/batch_norm.rst index c1c2c18f9e..65c95c2e3c 100644 --- a/docs/guides/batch_norm.rst +++ b/docs/guides/batch_norm.rst @@ -81,7 +81,7 @@ The ``batch_stats`` collection must be extracted from the ``variables`` for late mlp = MLP() x = jnp.ones((1, 3)) - variables = mlp.init(jax.random.PRNGKey(0), x) + variables = mlp.init(jax.random.key(0), x) params = variables['params'] @@ -89,7 +89,7 @@ The ``batch_stats`` collection must be extracted from the ``variables`` for late --- mlp = MLP() x = jnp.ones((1, 3)) - variables = mlp.init(jax.random.PRNGKey(0), x, train=False) #! + variables = mlp.init(jax.random.key(0), x, train=False) #! params = variables['params'] batch_stats = variables['batch_stats'] #! diff --git a/docs/guides/convert_pytorch_to_flax.rst b/docs/guides/convert_pytorch_to_flax.rst index 0d89858876..ff4e58acd3 100644 --- a/docs/guides/convert_pytorch_to_flax.rst +++ b/docs/guides/convert_pytorch_to_flax.rst @@ -31,7 +31,7 @@ and the Flax kernel has shape [inC, outC]. Transposing the kernel will do the tr # [outC, inC] -> [inC, outC] kernel = jnp.transpose(kernel, (1, 0)) - key = random.PRNGKey(0) + key = random.key(0) x = random.normal(key, (1, 3)) variables = {'params': {'kernel': kernel, 'bias': bias}} @@ -62,7 +62,7 @@ and the Flax kernel has shape [kH, kW, inC, outC]. 
Transposing the kernel will d # [outC, inC, kH, kW] -> [kH, kW, inC, outC] kernel = jnp.transpose(kernel, (2, 3, 1, 0)) - key = random.PRNGKey(0) + key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) variables = {'params': {'kernel': kernel, 'bias': bias}} @@ -154,7 +154,7 @@ Other than the transpose operation before reshaping, we can convert the weights variables = {'params': {'conv': {'kernel': conv_kernel, 'bias': conv_bias}, 'fc': {'kernel': fc_kernel, 'bias': fc_bias}}} - key = random.PRNGKey(0) + key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) j_out = j_model.apply(variables, x) @@ -192,7 +192,7 @@ while Flax multiplies the estimated statistic with ``momentum`` and the new obse variables = {'params': {'scale': scale, 'bias': bias}, 'batch_stats': {'mean': mean, 'var': var}} - key = random.PRNGKey(0) + key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) j_bn = nn.BatchNorm(momentum=0.9, use_running_average=True) @@ -241,7 +241,7 @@ operation. ``nn.pool()`` is the core function behind |nn.avg_pool()|_ and |nn.ma return y - key = random.PRNGKey(0) + key = random.key(0) x = random.normal(key, (1, 6, 6, 3)) j_out = avg_pool(x, window_shape=(2, 2), strides=(1, 1), padding=((1, 1), (1, 1))) diff --git a/docs/guides/dropout.rst b/docs/guides/dropout.rst index 2f48594da8..fa9345082f 100644 --- a/docs/guides/dropout.rst +++ b/docs/guides/dropout.rst @@ -27,7 +27,7 @@ desirable properties for neural networks. To learn more, refer to the `Pseudorandom numbers in JAX tutorial `__. **Note:** Recall that JAX has an explicit way of giving you PRNG keys: -you can fork the main PRNG state (such as ``key = jax.random.PRNGKey(seed=0)``) +you can fork the main PRNG state (such as ``key = jax.random.key(seed=0)``) into multiple new PRNG keys with ``key, subkey = jax.random.split(key)``. You can refresh your memory in `🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys `__. @@ -41,10 +41,10 @@ into three keys, including one for Flax Linen ``Dropout``. :title_right: With Dropout :sync: - root_key = jax.random.PRNGKey(seed=0) + root_key = jax.random.key(seed=0) main_key, params_key = jax.random.split(key=root_key) --- - root_key = jax.random.PRNGKey(seed=0) + root_key = jax.random.key(seed=0) main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3) #! **Note:** In Flax, you provide *PRNG streams* with *names*, so that you can use them later diff --git a/docs/guides/ensembling.rst b/docs/guides/ensembling.rst index 3f3d04cb93..89b2dbcdfd 100644 --- a/docs/guides/ensembling.rst +++ b/docs/guides/ensembling.rst @@ -224,7 +224,7 @@ directly. train_ds, test_ds = get_datasets() #! - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) rng, init_rng = jax.random.split(rng) state = create_train_state(init_rng, learning_rate, momentum) #! @@ -246,7 +246,7 @@ directly. --- train_ds, test_ds = get_datasets() test_ds = jax_utils.replicate(test_ds) #! - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) rng, init_rng = jax.random.split(rng) state = create_train_state(jax.random.split(init_rng, jax.device_count()), #! 
diff --git a/docs/guides/extracting_intermediates.rst b/docs/guides/extracting_intermediates.rst index aa80face44..34b98d1685 100644 --- a/docs/guides/extracting_intermediates.rst +++ b/docs/guides/extracting_intermediates.rst @@ -124,7 +124,7 @@ Note that, by default ``sow`` appends values every time it is called: return output, features batch = jnp.ones((1,28,28,1)) - variables = init(jax.random.PRNGKey(0), batch) + variables = init(jax.random.key(0), batch) preds, feats = predict(variables, batch) assert len(feats) == 2 # Tuple with two values since module was called twice. @@ -180,7 +180,7 @@ avoid using ``nn.compact`` altogether. return RefactoredCNN().apply({"params": params}, x, method=lambda module, x: module.features(x)) - params = init(jax.random.PRNGKey(0), batch) + params = init(jax.random.key(0), batch) features(params, batch) @@ -209,7 +209,7 @@ In the following code example we check if any intermediate activations are non-f fin = jax.tree_util.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates) return y, fin - variables = init(jax.random.PRNGKey(0), batch) + variables = init(jax.random.key(0), batch) y, is_finite = predict(variables, batch) all_finite = all(jax.tree_util.tree_leaves(is_finite)) assert all_finite, "non-finite intermediate detected!" @@ -250,8 +250,8 @@ non-layer intermediates, but the filter function won't be applied to it. def predict(params, x): return Model().apply({"params": params}, x, capture_intermediates=True) - batch = jax.random.uniform(jax.random.PRNGKey(1), (1,3)) - params = init(jax.random.PRNGKey(0), batch) + batch = jax.random.uniform(jax.random.key(1), (1,3)) + params = init(jax.random.key(0), batch) preds, feats = predict(params, batch) feats # intermediate c in Model was not stored because it's not a Flax layer --- @@ -276,8 +276,8 @@ non-layer intermediates, but the filter function won't be applied to it. filter_fn = lambda mdl, method_name: isinstance(mdl.name, str) and (mdl.name in {'Dense_0', 'Dense_2'}) #! return Model().apply({"params": params}, x, capture_intermediates=filter_fn) #! - batch = jax.random.uniform(jax.random.PRNGKey(1), (1,3)) - params = init(jax.random.PRNGKey(0), batch) + batch = jax.random.uniform(jax.random.key(1), (1,3)) + params = init(jax.random.key(0), batch) preds, feats = predict(params, batch) feats # intermediate c in Model is stored and isn't filtered out by the filter function #! @@ -337,7 +337,7 @@ your model more explicitly. 
return Sequential(SeqCNN().layers[0:7]).apply({"params": params}, x) batch = jnp.ones((1,28,28,1)) - params = init(jax.random.PRNGKey(0), batch) + params = init(jax.random.key(0), batch) features(params, batch) Extracting gradients of intermediate values @@ -367,7 +367,7 @@ the model: y = jnp.empty((1, 2)) # random data model = Model() - variables = model.init(jax.random.PRNGKey(1), x) + variables = model.init(jax.random.key(1), x) params, perturbations = variables['params'], variables['perturbations'] Finally compute the gradients of the loss with respect to the perturbations, diff --git a/docs/guides/flax_basics.ipynb b/docs/guides/flax_basics.ipynb index e9b4df0192..b76289be82 100644 --- a/docs/guides/flax_basics.ipynb +++ b/docs/guides/flax_basics.ipynb @@ -147,7 +147,7 @@ } ], "source": [ - "key1, key2 = random.split(random.PRNGKey(0))\n", + "key1, key2 = random.split(random.key(0))\n", "x = random.normal(key1, (10,)) # Dummy input data\n", "params = model.init(key2, x) # Initialization call\n", "jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes" @@ -241,7 +241,7 @@ "y_dim = 5\n", "\n", "# Generate random ground truth W and b.\n", - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "k1, k2 = random.split(key)\n", "W = random.normal(k1, (x_dim, y_dim))\n", "b = random.normal(k2, (y_dim,))\n", @@ -597,7 +597,7 @@ " x = nn.relu(x)\n", " return x\n", "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = ExplicitMLP(features=[3,4,5])\n", @@ -699,7 +699,7 @@ " # the default autonames would be \"Dense_0\", \"Dense_1\", ...\n", " return x\n", "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = SimpleMLP(features=[3,4,5])\n", @@ -781,7 +781,7 @@ " y = y + bias\n", " return y\n", "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = SimpleDense(features=3)\n", @@ -874,7 +874,7 @@ " return x - ra_mean.value + bias\n", "\n", "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "key1, key2 = random.split(random.key(0), 2)\n", "x = jnp.ones((10,5))\n", "model = BiasAdderWithRunningMean()\n", "variables = model.init(key1, x)\n", @@ -993,7 +993,7 @@ " return opt_state, params, state\n", "\n", "x = jnp.ones((10,5))\n", - "variables = model.init(random.PRNGKey(0), x)\n", + "variables = model.init(random.key(0), x)\n", "state, params = flax.core.pop(variables, 'params')\n", "del variables\n", "tx = optax.sgd(learning_rate=0.02)\n", diff --git a/docs/guides/flax_basics.md b/docs/guides/flax_basics.md index dfaa012052..1df7944414 100644 --- a/docs/guides/flax_basics.md +++ b/docs/guides/flax_basics.md @@ -79,7 +79,7 @@ Parameters are not stored with the models themselves. You need to initialize par :id: K529lhzeYtl8 :outputId: 06feb9d2-db50-4f41-c169-6df4336f43a5 -key1, key2 = random.split(random.PRNGKey(0)) +key1, key2 = random.split(random.key(0)) x = random.normal(key1, (10,)) # Dummy input data params = model.init(key2, x) # Initialization call jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes @@ -127,7 +127,7 @@ x_dim = 10 y_dim = 5 # Generate random ground truth W and b. 
-key = random.PRNGKey(0) +key = random.key(0) k1, k2 = random.split(key) W = random.normal(k1, (x_dim, y_dim)) b = random.normal(k2, (y_dim,)) @@ -299,7 +299,7 @@ class ExplicitMLP(nn.Module): x = nn.relu(x) return x -key1, key2 = random.split(random.PRNGKey(0), 2) +key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = ExplicitMLP(features=[3,4,5]) @@ -355,7 +355,7 @@ class SimpleMLP(nn.Module): # the default autonames would be "Dense_0", "Dense_1", ... return x -key1, key2 = random.split(random.PRNGKey(0), 2) +key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = SimpleMLP(features=[3,4,5]) @@ -400,7 +400,7 @@ class SimpleDense(nn.Module): y = y + bias return y -key1, key2 = random.split(random.PRNGKey(0), 2) +key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = SimpleDense(features=3) @@ -456,7 +456,7 @@ class BiasAdderWithRunningMean(nn.Module): return x - ra_mean.value + bias -key1, key2 = random.split(random.PRNGKey(0), 2) +key1, key2 = random.split(random.key(0), 2) x = jnp.ones((10,5)) model = BiasAdderWithRunningMean() variables = model.init(key1, x) @@ -508,7 +508,7 @@ def update_step(tx, apply_fn, x, opt_state, params, state): return opt_state, params, state x = jnp.ones((10,5)) -variables = model.init(random.PRNGKey(0), x) +variables = model.init(random.key(0), x) state, params = flax.core.pop(variables, 'params') del variables tx = optax.sgd(learning_rate=0.02) diff --git a/docs/guides/flax_on_pjit.ipynb b/docs/guides/flax_on_pjit.ipynb index a1d4d2b5c8..6ec633772a 100644 --- a/docs/guides/flax_on_pjit.ipynb +++ b/docs/guides/flax_on_pjit.ipynb @@ -323,7 +323,7 @@ "# Create fake inputs.\n", "x = jnp.ones((BATCH, DEPTH))\n", "# Initialize a PRNG key.\n", - "k = random.PRNGKey(0)\n", + "k = random.key(0)\n", "\n", "# Create an Optax optimizer.\n", "optimizer = optax.adam(learning_rate=0.001)\n", diff --git a/docs/guides/flax_on_pjit.md b/docs/guides/flax_on_pjit.md index a4e7f9f9cc..3b53875522 100644 --- a/docs/guides/flax_on_pjit.md +++ b/docs/guides/flax_on_pjit.md @@ -218,7 +218,7 @@ BATCH, LAYERS, DEPTH, USE_SCAN = 8, 4, 1024, False # Create fake inputs. x = jnp.ones((BATCH, DEPTH)) # Initialize a PRNG key. -k = random.PRNGKey(0) +k = random.key(0) # Create an Optax optimizer. optimizer = optax.adam(learning_rate=0.001) diff --git a/docs/guides/haiku_migration_guide.rst b/docs/guides/haiku_migration_guide.rst index c0c6097a70..58f14380ea 100644 --- a/docs/guides/haiku_migration_guide.rst +++ b/docs/guides/haiku_migration_guide.rst @@ -9,7 +9,7 @@ and highlight the differences between the two libraries. import jax import jax.numpy as jnp - from jax.random import PRNGKey + from jax import random import optax import flax.linen as nn @@ -106,7 +106,7 @@ and ``apply`` methods. In Flax, you simply instantiate your Module. model = Model(256, 10) To get the model parameters in both libraries you use the ``init`` method -with a ``PRNGKey`` plus some inputs to run the model. The main difference here is +with a ``random.key`` plus some inputs to run the model. The main difference here is that Flax returns a mapping from collection names to nested array dictionaries, ``params`` is just one of these possible collections. In Haiku, you get the ``params`` structure directly. @@ -118,7 +118,7 @@ structure directly. sample_x = jax.numpy.ones((1, 784)) params = model.init( - PRNGKey(0), + random.key(0), sample_x, training=False # <== inputs ) ... @@ -127,7 +127,7 @@ structure directly. 
sample_x = jax.numpy.ones((1, 784)) variables = model.init( - PRNGKey(0), + random.key(0), sample_x, training=False # <== inputs ) params = variables["params"] @@ -221,13 +221,13 @@ the random dropout masks. .. testcode:: :hide: - train_step(PRNGKey(0), params, sample_x, jnp.ones((1,), dtype=jnp.int32)) + train_step(random.key(0), params, sample_x, jnp.ones((1,), dtype=jnp.int32)) The most notable differences is that in Flax you have to pass the parameters inside a dictionary with a ``params`` key, and the -PRNGKey inside a dictionary with a ``dropout`` key. This is because in Flax +key inside a dictionary with a ``dropout`` key. This is because in Flax you can have many types of model state and random state. In Haiku, you -just pass the parameters and the PRNGKey directly. +just pass the parameters and the key directly. Handling State ----------------- @@ -310,7 +310,7 @@ of a Haiku model with an ``hk.BatchNorm`` layer. In Flax, we can set sample_x = jax.numpy.ones((1, 784)) params, state = model.init( - PRNGKey(0), + random.key(0), sample_x, training=True # <== inputs #! ) ... @@ -319,7 +319,7 @@ of a Haiku model with an ``hk.BatchNorm`` layer. In Flax, we can set sample_x = jax.numpy.ones((1, 784)) variables = model.init( - PRNGKey(0), #! + random.key(0), #! sample_x, training=False # <== inputs ) params, batch_stats = variables["params"], variables["batch_stats"] @@ -490,7 +490,7 @@ method. This will create all the necessary parameters for the model. :sync: params = model.init( - PRNGKey(0), + random.key(0), x=jax.numpy.ones((1, 784)), ) ... @@ -498,7 +498,7 @@ method. This will create all the necessary parameters for the model. --- variables = model.init( - PRNGKey(0), + random.key(0), x=jax.numpy.ones((1, 784)), ) params = variables["params"] @@ -689,7 +689,7 @@ Finally, let's quickly view how the ``RNN`` Module would be used in both Haiku a model = hk.without_apply_rng(hk.transform(forward)) params = model.init( - PRNGKey(0), + random.key(0), x=jax.numpy.ones((3, 12, 32)), ) @@ -706,7 +706,7 @@ Finally, let's quickly view how the ``RNN`` Module would be used in both Haiku a model = RNN(64) variables = model.init( - PRNGKey(0), + random.key(0), x=jax.numpy.ones((3, 12, 32)), ) params = variables['params'] @@ -813,7 +813,7 @@ we will be specifying that we want to use ``5`` layers each with ``64`` features sample_x = jax.numpy.ones((1, 64)) params = model.init( - PRNGKey(0), + random.key(0), sample_x, training=False # <== inputs ) ... 
@@ -827,7 +827,7 @@ we will be specifying that we want to use ``5`` layers each with ``64`` features sample_x = jax.numpy.ones((1, 64)) variables = model.init( - PRNGKey(0), + random.key(0), sample_x, training=False # <== inputs ) params = variables['params'] diff --git a/docs/guides/jax_for_the_impatient.ipynb b/docs/guides/jax_for_the_impatient.ipynb index fb2d30f921..71a02386b7 100644 --- a/docs/guides/jax_for_the_impatient.ipynb +++ b/docs/guides/jax_for_the_impatient.ipynb @@ -314,7 +314,7 @@ } ], "source": [ - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "key" ] }, @@ -471,7 +471,7 @@ } ], "source": [ - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "def f(x):\n", " return jnp.dot(x.T,x)/2.0\n", "\n", @@ -794,7 +794,7 @@ "y_dim = 5\n", "\n", "# Generate random ground truth W and b.\n", - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "k1, k2 = random.split(key)\n", "W = random.normal(k1, (x_dim, y_dim))\n", "b = random.normal(k2, (y_dim,))\n", diff --git a/docs/guides/jax_for_the_impatient.md b/docs/guides/jax_for_the_impatient.md index 5d51f910fc..509eb469aa 100644 --- a/docs/guides/jax_for_the_impatient.md +++ b/docs/guides/jax_for_the_impatient.md @@ -136,7 +136,7 @@ In short, you need to explicitly manage the PRNGs (pseudo random number generato :id: 8iz9KGF4s7nN :outputId: c5bb1581-090b-42ed-cc42-08436154bc14 -key = random.PRNGKey(0) +key = random.key(0) key ``` @@ -203,7 +203,7 @@ $$\nabla f(x) = x$$ :id: zDOydrLMcIzp :outputId: 580c14ed-d1a3-4f92-c9b9-78d58c87bc76 -key = random.PRNGKey(0) +key = random.key(0) def f(x): return jnp.dot(x.T,x)/2.0 @@ -383,7 +383,7 @@ x_dim = 10 y_dim = 5 # Generate random ground truth W and b. -key = random.PRNGKey(0) +key = random.key(0) k1, k2 = random.split(key) W = random.normal(k1, (x_dim, y_dim)) b = random.normal(k2, (y_dim,)) diff --git a/docs/guides/lr_schedule.rst b/docs/guides/lr_schedule.rst index 90f0c7d971..fa530c66ab 100644 --- a/docs/guides/lr_schedule.rst +++ b/docs/guides/lr_schedule.rst @@ -222,7 +222,7 @@ And the ``create_train_state`` function: steps_per_epoch = train_ds_size // config.batch_size learning_rate_fn = create_learning_rate_fn(config, config.learning_rate, steps_per_epoch) - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) state = create_train_state(rng, config, learning_rate_fn) train_ds = get_dummy_data(config.train_ds_size) diff --git a/docs/guides/model_surgery.ipynb b/docs/guides/model_surgery.ipynb index 77d12edc7c..21bf41e819 100644 --- a/docs/guides/model_surgery.ipynb +++ b/docs/guides/model_surgery.ipynb @@ -92,7 +92,7 @@ " initial_params = CNN().init(key, init_shape)['params']\n", " return initial_params\n", "\n", - "key = jax.random.PRNGKey(0)\n", + "key = jax.random.key(0)\n", "params = get_initial_params(key)\n", "\n", "jax.tree_util.tree_map(jnp.shape, params)" diff --git a/docs/guides/model_surgery.md b/docs/guides/model_surgery.md index 25167257a8..bb53710fa1 100644 --- a/docs/guides/model_surgery.md +++ b/docs/guides/model_surgery.md @@ -64,7 +64,7 @@ def get_initial_params(key): initial_params = CNN().init(key, init_shape)['params'] return initial_params -key = jax.random.PRNGKey(0) +key = jax.random.key(0) params = get_initial_params(key) jax.tree_util.tree_map(jnp.shape, params) diff --git a/docs/guides/optax_update_guide.rst b/docs/guides/optax_update_guide.rst index 559e72aa8d..f2b374878c 100644 --- a/docs/guides/optax_update_guide.rst +++ b/docs/guides/optax_update_guide.rst @@ -27,7 +27,7 @@ https://optax.readthedocs.io/en/latest/optax-101.html ds_train = 
[batch] get_ds_train = lambda: [batch] model = nn.Dense(1) - variables = model.init(jax.random.PRNGKey(0), batch['image']) + variables = model.init(jax.random.key(0), batch['image']) learning_rate, momentum, weight_decay, grad_clip_norm = .1, .9, 1e-3, 1. loss = lambda params, batch: jnp.array(0.) diff --git a/docs/guides/regular_dict_upgrade_guide.rst b/docs/guides/regular_dict_upgrade_guide.rst index 6e20dea29c..8659ad5a3e 100644 --- a/docs/guides/regular_dict_upgrade_guide.rst +++ b/docs/guides/regular_dict_upgrade_guide.rst @@ -32,7 +32,7 @@ The following are the utility functions and example upgrade patterns: import jax.numpy as jnp x = jnp.empty((1,3)) - variables = flax.core.freeze(nn.Dense(5).init(jax.random.PRNGKey(0), x)) + variables = flax.core.freeze(nn.Dense(5).init(jax.random.key(0), x)) other_variables = jnp.array([1, 1, 1, 1, 1], dtype=jnp.float32) @@ -107,12 +107,12 @@ For example: x = jnp.empty((1,3)) flax.config.update('flax_return_frozendict', True) # set Flax to return FrozenDicts - variables = nn.Dense(5).init(jax.random.PRNGKey(0), x) + variables = nn.Dense(5).init(jax.random.key(0), x) assert isinstance(variables, flax.core.FrozenDict) flax.config.update('flax_return_frozendict', False) # set Flax to return regular dicts - variables = nn.Dense(5).init(jax.random.PRNGKey(0), x) + variables = nn.Dense(5).init(jax.random.key(0), x) assert isinstance(variables, dict) diff --git a/docs/guides/rnncell_upgrade_guide.rst b/docs/guides/rnncell_upgrade_guide.rst index 773d1dbf57..7629c3bb35 100644 --- a/docs/guides/rnncell_upgrade_guide.rst +++ b/docs/guides/rnncell_upgrade_guide.rst @@ -53,18 +53,18 @@ signature only requires a PRNG key and a sample input: :title_right: New :sync: - carry = nn.LSTMCell.initialize_carry(jax.random.PRNGKey(0), (batch_size,), out_features) + carry = nn.LSTMCell.initialize_carry(jax.random.key(0), (batch_size,), out_features) --- - carry = cell.initialize_carry(jax.random.PRNGKey(0), x[:, 0].shape) + carry = cell.initialize_carry(jax.random.key(0), x[:, 0].shape) Here, ``x[:, 0].shape`` represents the input for the cell (without the time dimension). You can also just create the input shape directly when its more convenient: .. 
testcode:: - carry = cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, in_features)) + carry = cell.initialize_carry(jax.random.key(0), (batch_size, in_features)) Upgrade Patterns @@ -99,7 +99,7 @@ it working, albeit not in the most idiomatic way: @staticmethod def initialize_carry(batch_dims, hidden_size): return nn.OptimizedLSTMCell.initialize_carry( - jax.random.PRNGKey(0), batch_dims, hidden_size) + jax.random.key(0), batch_dims, hidden_size) --- @@ -118,7 +118,7 @@ it working, albeit not in the most idiomatic way: @staticmethod def initialize_carry(batch_dims, hidden_size): return nn.OptimizedLSTMCell(hidden_size, parent=None).initialize_carry( - jax.random.PRNGKey(0), (*batch_dims, hidden_size)) + jax.random.key(0), (*batch_dims, hidden_size)) Notice how in the new version, we have to extract the number of features from the carry during ``__call__``, and use ``parent=None`` during ``initialize_carry`` to avoid some potential @@ -147,11 +147,11 @@ a ``nn.scan``-ed version of the cell in the ``setup`` method: @staticmethod def initialize_carry(batch_dims, hidden_size): return nn.OptimizedLSTMCell.initialize_carry( - jax.random.PRNGKey(0), batch_dims, hidden_size) + jax.random.key(0), batch_dims, hidden_size) model = SimpleLSTM() carry = SimpleLSTM.initialize_carry((batch_size,), out_features) - variables = model.init(jax.random.PRNGKey(0), carry, x) + variables = model.init(jax.random.key(0), carry, x) --- @@ -168,12 +168,12 @@ a ``nn.scan``-ed version of the cell in the ``setup`` method: @nn.compact def __call__(self, x): - carry = self.scan_cell.initialize_carry(jax.random.PRNGKey(0), x[:, 0].shape) + carry = self.scan_cell.initialize_carry(jax.random.key(0), x[:, 0].shape) return self.scan_cell(carry, x)[1] # only return the output model = SimpleLSTM(features=out_features) - variables = model.init(jax.random.PRNGKey(0), x) + variables = model.init(jax.random.key(0), x) Because the ``carry`` can be easily initialized from the sample input, we can move the call to ``initialize_carry`` into the ``__call__`` method, somewhat simplifying the code. diff --git a/docs/guides/state_params.rst b/docs/guides/state_params.rst index fd8181098a..c0864bdd06 100644 --- a/docs/guides/state_params.rst +++ b/docs/guides/state_params.rst @@ -68,7 +68,7 @@ Then we can write the actual training code. .. testcode:: model = BiasAdderWithRunningMean() - variables = model.init(random.PRNGKey(0), dummy_input) + variables = model.init(random.key(0), dummy_input) # Split state and params (which are updated by optimizer). state, params = flax.core.pop(variables, 'params') del variables # Delete variables to avoid wasting resources @@ -105,7 +105,7 @@ the :code:`axis_name` argument of :code:`lax.pmean()` directly. # Create some fake data and run only for one epoch for testing. dummy_input = jnp.ones((100,)) - key1, key2 = random.split(random.PRNGKey(0), num=2) + key1, key2 = random.split(random.key(0), num=2) batch_size = 64 X = random.normal(key1, (batch_size, 100)) Y = random.normal(key2, (batch_size, 1)) @@ -171,7 +171,7 @@ dimension. Now we are able to train the model: .. testcode:: model = MLP(hidden_size=10, out_size=1) - variables = model.init(random.PRNGKey(0), dummy_input) + variables = model.init(random.key(0), dummy_input) # Split state and params (which are updated by optimizer). 
state, params = flax.core.pop(variables, 'params') del variables # Delete variables to avoid wasting resources diff --git a/docs/guides/transfer_learning.ipynb b/docs/guides/transfer_learning.ipynb index 35db1f2389..d2bd152304 100644 --- a/docs/guides/transfer_learning.ipynb +++ b/docs/guides/transfer_learning.ipynb @@ -156,7 +156,7 @@ "model = Classifier(num_classes=num_classes, backbone=vision_model)\n", "\n", "x = jnp.empty((1, 224, 224, 3))\n", - "variables = model.init(jax.random.PRNGKey(1), x)\n", + "variables = model.init(jax.random.key(1), x)\n", "params = variables['params']" ] }, diff --git a/docs/guides/transfer_learning.md b/docs/guides/transfer_learning.md index e54eddf95e..d467139267 100644 --- a/docs/guides/transfer_learning.md +++ b/docs/guides/transfer_learning.md @@ -106,7 +106,7 @@ num_classes = 3 model = Classifier(num_classes=num_classes, backbone=vision_model) x = jnp.empty((1, 224, 224, 3)) -variables = model.init(jax.random.PRNGKey(1), x) +variables = model.init(jax.random.key(1), x) params = variables['params'] ``` diff --git a/docs/guides/use_checkpointing.ipynb b/docs/guides/use_checkpointing.ipynb index dca0846a1d..94f313833e 100644 --- a/docs/guides/use_checkpointing.ipynb +++ b/docs/guides/use_checkpointing.ipynb @@ -169,7 +169,7 @@ ], "source": [ "# A simple model with one linear layer.\n", - "key1, key2 = random.split(random.PRNGKey(0))\n", + "key1, key2 = random.split(random.key(0))\n", "x1 = random.normal(key1, (5,)) # A simple JAX array.\n", "model = nn.Dense(features=3)\n", "variables = model.init(key2, x1)\n", diff --git a/docs/guides/use_checkpointing.md b/docs/guides/use_checkpointing.md index c5b38afcd2..9756423beb 100644 --- a/docs/guides/use_checkpointing.md +++ b/docs/guides/use_checkpointing.md @@ -96,7 +96,7 @@ First, create a pytree with many data structures and containers, and play with i ```python id="56dec3f6" outputId="f1856d96-1961-48ed-bb7c-cb63fbaa7567" # A simple model with one linear layer. -key1, key2 = random.split(random.PRNGKey(0)) +key1, key2 = random.split(random.key(0)) x1 = random.normal(key1, (5,)) # A simple JAX array. model = nn.Dense(features=3) variables = model.init(key2, x1) diff --git a/docs/index.rst b/docs/index.rst index 3de791f82f..28f3ffef52 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -110,7 +110,7 @@ Basic usage .. testsetup:: import jax - from jax.random import PRNGKey + from jax import random import flax.linen as nn import jax.numpy as jnp @@ -130,7 +130,7 @@ Basic usage model = MLP(out_dims=10) # instantiate the MLP model x = jnp.empty((4, 28, 28, 1)) # generate random data - variables = model.init(PRNGKey(42), x) # initialize the weights + variables = model.init(random.key(42), x)# initialize the weights y = model.apply(variables, x) # make forward pass ---- diff --git a/docs/notebooks/flax_sharp_bits.ipynb b/docs/notebooks/flax_sharp_bits.ipynb index 76c5a4f45c..3b10ab73c6 100644 --- a/docs/notebooks/flax_sharp_bits.ipynb +++ b/docs/notebooks/flax_sharp_bits.ipynb @@ -53,7 +53,7 @@ "\n", "The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation, requiring a PRNG state, and Flax (like JAX) uses [Threefry](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) PRNG that is splittable. 
\n", "\n", - "> Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.PRNGKey(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers).\n", + "> Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.key(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers).\n", "\n", "Flax provides an _implicit_ way of handling PRNG key streams via [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics)'s [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng) helper function. It allows the code in Flax `Module`s (or its sub-`Module`s) to \"pull PRNG keys\". `make_rng` guarantees to provide a unique key each time you call it.\n", "\n", @@ -84,7 +84,7 @@ "source": [ "# Randomness.\n", "seed = 0\n", - "root_key = jax.random.PRNGKey(seed=seed)\n", + "root_key = jax.random.key(seed=seed)\n", "main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)\n", "\n", "# A simple network.\n", diff --git a/docs/notebooks/flax_sharp_bits.md b/docs/notebooks/flax_sharp_bits.md index ac5530270f..eb312f7256 100644 --- a/docs/notebooks/flax_sharp_bits.md +++ b/docs/notebooks/flax_sharp_bits.md @@ -45,7 +45,7 @@ Check out a full example below. The [dropout](https://jmlr.org/papers/volume15/srivastava14a/srivastava14a.pdf) stochastic regularization technique randomly removes hidden and visible units in a network. Dropout is a random operation, requiring a PRNG state, and Flax (like JAX) uses [Threefry](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) PRNG that is splittable. -> Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.PRNGKey(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers). +> Note: Recall that JAX has an explicit way of giving you PRNG keys: you can fork the main PRNG state (such as `key = jax.random.key(seed=0)`) into multiple new PRNG keys with `key, subkey = jax.random.split(key)`. Refresh your memory in [🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers). Flax provides an _implicit_ way of handling PRNG key streams via [Flax `Module`](https://flax.readthedocs.io/en/latest/guides/flax_basics.html#module-basics)'s [`flax.linen.Module.make_rng`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.make_rng) helper function. It allows the code in Flax `Module`s (or its sub-`Module`s) to "pull PRNG keys". `make_rng` guarantees to provide a unique key each time you call it. @@ -65,7 +65,7 @@ import flax.linen as nn ```{code-cell} ipython3 # Randomness. 
seed = 0 -root_key = jax.random.PRNGKey(seed=seed) +root_key = jax.random.key(seed=seed) main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3) # A simple network. diff --git a/docs/notebooks/linen_intro.ipynb b/docs/notebooks/linen_intro.ipynb index ab272e83cd..10fb393216 100644 --- a/docs/notebooks/linen_intro.ipynb +++ b/docs/notebooks/linen_intro.ipynb @@ -161,7 +161,7 @@ ], "source": [ "# Make RNG Keys and a fake input.\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "# provide key and fake input to get initialized variables\n", @@ -299,7 +299,7 @@ " x = nn.relu(x)\n", " return x\n", "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = ExplicitMLP(features=[3,4,5])\n", @@ -363,7 +363,7 @@ " # x = nn.Dense(feat)(x)\n", " return x\n", "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = SimpleMLP(features=[3,4,5])\n", @@ -453,7 +453,7 @@ " y = y + bias\n", " return y\n", "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = SimpleDense(features=3)\n", @@ -524,7 +524,7 @@ " y = y + self.bias\n", " return y\n", "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "key1, key2 = random.split(random.key(0), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = ExplicitDense(features_in=4, features=3)\n", @@ -607,7 +607,7 @@ " return counter.value\n", "\n", "\n", - "key1 = random.PRNGKey(0)\n", + "key1 = random.key(0)\n", "\n", "model = Counter()\n", "init_variables = model.init(key1)\n", @@ -737,7 +737,7 @@ " x = nn.BatchNorm(use_running_average=not self.training)(x)\n", " return x\n", "\n", - "key1, key2, key3, key4 = random.split(random.PRNGKey(0), 4)\n", + "key1, key2, key3, key4 = random.split(random.key(0), 4)\n", "x = random.uniform(key1, (3,4,4))\n", "\n", "model = Block(features=3, training=True)\n", @@ -835,7 +835,7 @@ " x = nn.relu(x)\n", " return x\n", "\n", - "key1, key2 = random.split(random.PRNGKey(3), 2)\n", + "key1, key2 = random.split(random.key(3), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = MLP(features=[3,4,5])\n", @@ -906,7 +906,7 @@ " x = nn.relu(x)\n", " return x\n", "\n", - "key1, key2 = random.split(random.PRNGKey(3), 2)\n", + "key1, key2 = random.split(random.key(3), 2)\n", "x = random.uniform(key1, (4,4))\n", "\n", "model = RematMLP(features=[3,4,5])\n", @@ -1059,7 +1059,7 @@ " return y.mean(axis=-2)\n", "\n", "\n", - "key1, key2, key3, key4 = random.split(random.PRNGKey(0), 4)\n", + "key1, key2, key3, key4 = random.split(random.key(0), 4)\n", "x = random.uniform(key1, (3, 13, 64))\n", "\n", "model = functools.partial(\n", @@ -1137,13 +1137,13 @@ " split_rngs={'params': False})\n", " lstm = LSTM(self.features, name=\"lstm_cell\")\n", "\n", - " dummy_rng = random.PRNGKey(0)\n", + " dummy_rng = random.key(0)\n", " input_shape = xs[:, 0].shape\n", " init_carry = lstm.initialize_carry(dummy_rng, input_shape)\n", "\n", " return lstm(init_carry, xs)\n", "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "key1, key2 = random.split(random.key(0), 2)\n", "xs = random.uniform(key1, (1, 5, 2))\n", "\n", "model = SimpleScan(2)\n", diff --git a/docs/notebooks/linen_intro.md 
b/docs/notebooks/linen_intro.md index b3c94187fa..aed5511a33 100644 --- a/docs/notebooks/linen_intro.md +++ b/docs/notebooks/linen_intro.md @@ -83,7 +83,7 @@ We call the `init` method on the instantiated Module. If the Module `__call__` :outputId: 3adfaeaf-977e-4e82-8adf-d254fae6eb91 # Make RNG Keys and a fake input. -key1, key2 = random.split(random.PRNGKey(0), 2) +key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) # provide key and fake input to get initialized variables @@ -147,7 +147,7 @@ class ExplicitMLP(nn.Module): x = nn.relu(x) return x -key1, key2 = random.split(random.PRNGKey(0), 2) +key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = ExplicitMLP(features=[3,4,5]) @@ -182,7 +182,7 @@ class SimpleMLP(nn.Module): # x = nn.Dense(feat)(x) return x -key1, key2 = random.split(random.PRNGKey(0), 2) +key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = SimpleMLP(features=[3,4,5]) @@ -233,7 +233,7 @@ class SimpleDense(nn.Module): y = y + bias return y -key1, key2 = random.split(random.PRNGKey(0), 2) +key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = SimpleDense(features=3) @@ -271,7 +271,7 @@ class ExplicitDense(nn.Module): y = y + self.bias return y -key1, key2 = random.split(random.PRNGKey(0), 2) +key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (4,4)) model = ExplicitDense(features_in=4, features=3) @@ -316,7 +316,7 @@ class Counter(nn.Module): return counter.value -key1 = random.PRNGKey(0) +key1 = random.key(0) model = Counter() init_variables = model.init(key1) @@ -351,7 +351,7 @@ class Block(nn.Module): x = nn.BatchNorm(use_running_average=not self.training)(x) return x -key1, key2, key3, key4 = random.split(random.PRNGKey(0), 4) +key1, key2, key3, key4 = random.split(random.key(0), 4) x = random.uniform(key1, (3,4,4)) model = Block(features=3, training=True) @@ -412,7 +412,7 @@ class MLP(nn.Module): x = nn.relu(x) return x -key1, key2 = random.split(random.PRNGKey(3), 2) +key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4,4)) model = MLP(features=[3,4,5]) @@ -452,7 +452,7 @@ class RematMLP(nn.Module): x = nn.relu(x) return x -key1, key2 = random.split(random.PRNGKey(3), 2) +key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4,4)) model = RematMLP(features=[3,4,5]) @@ -577,7 +577,7 @@ class MultiHeadDotProductAttention(nn.Module): return y.mean(axis=-2) -key1, key2, key3, key4 = random.split(random.PRNGKey(0), 4) +key1, key2, key3, key4 = random.split(random.key(0), 4) x = random.uniform(key1, (3, 13, 64)) model = functools.partial( @@ -623,13 +623,13 @@ class SimpleScan(nn.Module): split_rngs={'params': False}) lstm = LSTM(self.features, name="lstm_cell") - dummy_rng = random.PRNGKey(0) + dummy_rng = random.key(0) input_shape = xs[:, 0].shape init_carry = lstm.initialize_carry(dummy_rng, input_shape) return lstm(init_carry, xs) -key1, key2 = random.split(random.PRNGKey(0), 2) +key1, key2 = random.split(random.key(0), 2) xs = random.uniform(key1, (1, 5, 2)) model = SimpleScan(2) diff --git a/docs/notebooks/optax_update_guide.ipynb b/docs/notebooks/optax_update_guide.ipynb index 1b18a507e5..0f48400443 100644 --- a/docs/notebooks/optax_update_guide.ipynb +++ b/docs/notebooks/optax_update_guide.ipynb @@ -125,7 +125,7 @@ " return jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))\n", "\n", "model = Perceptron([50, 10])\n", - "variables = model.init(jax.random.PRNGKey(0), 
batch['image'])\n", + "variables = model.init(jax.random.key(0), batch['image'])\n", "\n", "jax.tree_util.tree_map(jnp.shape, variables)" ] diff --git a/docs/notebooks/optax_update_guide.md b/docs/notebooks/optax_update_guide.md index 2ad9a6e8c8..3d77d0f153 100644 --- a/docs/notebooks/optax_update_guide.md +++ b/docs/notebooks/optax_update_guide.md @@ -71,7 +71,7 @@ def loss(params, batch): return jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) model = Perceptron([50, 10]) -variables = model.init(jax.random.PRNGKey(0), batch['image']) +variables = model.init(jax.random.key(0), batch['image']) jax.tree_util.tree_map(jnp.shape, variables) ``` diff --git a/docs/notebooks/state_params.ipynb b/docs/notebooks/state_params.ipynb index 580b254fac..fb625cefdf 100644 --- a/docs/notebooks/state_params.ipynb +++ b/docs/notebooks/state_params.ipynb @@ -63,12 +63,12 @@ "# Initialize random variables\n", "dummy_input = jnp.ones((32, 5))\n", "\n", - "X = random.uniform(random.PRNGKey(0), (128, 5), minval=0.0, maxval=1.0)\n", - "noise = random.uniform(random.PRNGKey(0), (), minval=0.0, maxval=0.1)\n", + "X = random.uniform(random.key(0), (128, 5), minval=0.0, maxval=1.0)\n", + "noise = random.uniform(random.key(0), (), minval=0.0, maxval=0.1)\n", "X += noise\n", "\n", - "W = random.uniform(random.PRNGKey(0), (5, 1), minval=0.0, maxval=1.0)\n", - "b = random.uniform(random.PRNGKey(0), (), minval=0.0, maxval=1.0)\n", + "W = random.uniform(random.key(0), (5, 1), minval=0.0, maxval=1.0)\n", + "b = random.uniform(random.key(0), (), minval=0.0, maxval=1.0)\n", "\n", "Y = jnp.matmul(X, W) + b\n", "\n", @@ -151,7 +151,7 @@ "outputs": [], "source": [ "model = BiasAdderWithRunningMean()\n", - "variables = model.init(random.PRNGKey(0), dummy_input)\n", + "variables = model.init(random.key(0), dummy_input)\n", "# Split state and params (which are updated by optimizer).\n", "state, params = flax.core.pop(variables, 'params')\n", "del variables # Delete variables to avoid wasting resources\n", @@ -270,7 +270,7 @@ "outputs": [], "source": [ "model = MLP(hidden_size=10, out_size=1)\n", - "variables = model.init(random.PRNGKey(0), dummy_input)\n", + "variables = model.init(random.key(0), dummy_input)\n", "# Split state and params (which are updated by optimizer).\n", "state, params = flax.core.pop(variables, 'params')\n", "del variables # Delete variables to avoid wasting resources\n", diff --git a/docs/notebooks/state_params.md b/docs/notebooks/state_params.md index b1fb9dce31..3c16370be4 100644 --- a/docs/notebooks/state_params.md +++ b/docs/notebooks/state_params.md @@ -49,12 +49,12 @@ from flax import linen as nn # Initialize random variables dummy_input = jnp.ones((32, 5)) -X = random.uniform(random.PRNGKey(0), (128, 5), minval=0.0, maxval=1.0) -noise = random.uniform(random.PRNGKey(0), (), minval=0.0, maxval=0.1) +X = random.uniform(random.key(0), (128, 5), minval=0.0, maxval=1.0) +noise = random.uniform(random.key(0), (), minval=0.0, maxval=0.1) X += noise -W = random.uniform(random.PRNGKey(0), (5, 1), minval=0.0, maxval=1.0) -b = random.uniform(random.PRNGKey(0), (), minval=0.0, maxval=1.0) +W = random.uniform(random.key(0), (5, 1), minval=0.0, maxval=1.0) +b = random.uniform(random.key(0), (), minval=0.0, maxval=1.0) Y = jnp.matmul(X, W) + b @@ -112,7 +112,7 @@ Then we can write the actual training code. 
:id: 8RUFi57GVktj model = BiasAdderWithRunningMean() -variables = model.init(random.PRNGKey(0), dummy_input) +variables = model.init(random.key(0), dummy_input) # Split state and params (which are updated by optimizer). state, params = flax.core.pop(variables, 'params') del variables # Delete variables to avoid wasting resources @@ -201,7 +201,7 @@ Note that we also need to specify that the model state does not have a batch dim :id: OdMTQcMoYUtk model = MLP(hidden_size=10, out_size=1) -variables = model.init(random.PRNGKey(0), dummy_input) +variables = model.init(random.key(0), dummy_input) # Split state and params (which are updated by optimizer). state, params = flax.core.pop(variables, 'params') del variables # Delete variables to avoid wasting resources diff --git a/examples/imagenet/imagenet.ipynb b/examples/imagenet/imagenet.ipynb index 28f5bb82b8..2614653ef8 100644 --- a/examples/imagenet/imagenet.ipynb +++ b/examples/imagenet/imagenet.ipynb @@ -853,7 +853,7 @@ "learning_rate_fn = train.create_learning_rate_fn(\n", " config, base_learning_rate, steps_per_epoch)\n", "state = train.create_train_state(\n", - " jax.random.PRNGKey(0), config, model, image_size=input_pipeline.IMAGE_SIZE,\n", + " jax.random.key(0), config, model, image_size=input_pipeline.IMAGE_SIZE,\n", " learning_rate_fn=learning_rate_fn)\n", "state = train.restore_checkpoint(state, './')" ] diff --git a/examples/imagenet/models_test.py b/examples/imagenet/models_test.py index efe5118769..338df46b2b 100644 --- a/examples/imagenet/models_test.py +++ b/examples/imagenet/models_test.py @@ -31,7 +31,7 @@ class ResNetV1Test(parameterized.TestCase): def test_resnet_v1_model(self): """Tests ResNet V1 model definition and output (variables).""" - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) model_def = models.ResNet50(num_classes=10, dtype=jnp.float32) variables = model_def.init(rng, jnp.ones((8, 224, 224, 3), jnp.float32)) @@ -45,7 +45,7 @@ def test_resnet_v1_model(self): @parameterized.product(model=(models.ResNet18, models.ResNet18Local)) def test_resnet_18_v1_model(self, model): """Tests ResNet18 V1 model definition and output (variables).""" - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) model_def = model(num_classes=2, dtype=jnp.float32) variables = model_def.init(rng, jnp.ones((1, 64, 64, 3), jnp.float32)) diff --git a/examples/imagenet/train.py b/examples/imagenet/train.py index 034bca6019..b83b97de68 100644 --- a/examples/imagenet/train.py +++ b/examples/imagenet/train.py @@ -288,7 +288,7 @@ def train_and_evaluate( logdir=workdir, just_logging=jax.process_index() != 0 ) - rng = random.PRNGKey(0) + rng = random.key(0) image_size = 224 @@ -427,6 +427,6 @@ def train_and_evaluate( save_checkpoint(state, workdir) # Wait until computations are done before exiting - jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready() + jax.random.normal(jax.random.key(0), ()).block_until_ready() return state diff --git a/examples/imagenet/train_test.py b/examples/imagenet/train_test.py index f9cc657d67..076f429b23 100644 --- a/examples/imagenet/train_test.py +++ b/examples/imagenet/train_test.py @@ -44,9 +44,9 @@ def setUp(self): def test_create_model(self): """Tests creating model.""" model = train.create_model(model_cls=models._ResNet1, half_precision=False) # pylint: disable=protected-access - params, batch_stats = train.initialized(random.PRNGKey(0), 224, model) + params, batch_stats = train.initialized(random.key(0), 224, model) variables = {'params': params, 'batch_stats': batch_stats} - x = 
random.normal(random.PRNGKey(1), (8, 224, 224, 3)) + x = random.normal(random.key(1), (8, 224, 224, 3)) y = model.apply(variables, x, train=False) self.assertEqual(y.shape, (8, 1000)) @@ -58,9 +58,9 @@ def test_create_model_local(self): model = train.create_model( model_cls=models._ResNet1Local, half_precision=False ) # pylint: disable=protected-access - params, batch_stats = train.initialized(random.PRNGKey(0), 64, model) + params, batch_stats = train.initialized(random.key(0), 64, model) variables = {'params': params, 'batch_stats': batch_stats} - x = random.normal(random.PRNGKey(1), (1, 64, 64, 3)) + x = random.normal(random.key(1), (1, 64, 64, 3)) y = model.apply(variables, x, train=False) self.assertEqual(y.shape, (1, 1000)) diff --git a/examples/linen_design_test/attention_simple.py b/examples/linen_design_test/attention_simple.py index 7e38343b9a..e77d43c1b3 100644 --- a/examples/linen_design_test/attention_simple.py +++ b/examples/linen_design_test/attention_simple.py @@ -202,7 +202,7 @@ def __call__(self, inputs_q, inputs_kv, bias=None, dtype=jnp.float32): if __name__ == '__main__': inputs = jnp.ones((8, 97, 256)) - rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)} + rngs = {'params': random.key(0), 'dropout': random.key(1)} model = MultiHeadDotProductAttention( broadcast_dropout=False, qkv_features=256, diff --git a/examples/linen_design_test/autoencoder.py b/examples/linen_design_test/autoencoder.py index 1cd8ceba46..7c6a1fc9c6 100644 --- a/examples/linen_design_test/autoencoder.py +++ b/examples/linen_design_test/autoencoder.py @@ -69,7 +69,7 @@ def decode(self, z): # `ae.initialized` returns a materialized copy of `ae` by # running through an input to create submodules defined lazily. -params = ae.init({"params": random.PRNGKey(42)}, jnp.ones((1, 28, 28, 1))) +params = ae.init({"params": random.key(42)}, jnp.ones((1, 28, 28, 1))) # Now you can use `ae` as a normal object, calling any methods defined on AutoEncoder diff --git a/examples/linen_design_test/linear_regression.py b/examples/linen_design_test/linear_regression.py index b0ed964222..8bda1e1112 100644 --- a/examples/linen_design_test/linear_regression.py +++ b/examples/linen_design_test/linear_regression.py @@ -41,7 +41,7 @@ def init_params(rng): # Get initial parameters -params = init_params(jax.random.PRNGKey(42)) +params = init_params(jax.random.key(42)) print("initial params", params) # Run SGD. diff --git a/examples/linen_design_test/mlp_explicit.py b/examples/linen_design_test/mlp_explicit.py index 594145b86f..9953c4df4a 100644 --- a/examples/linen_design_test/mlp_explicit.py +++ b/examples/linen_design_test/mlp_explicit.py @@ -56,7 +56,7 @@ def __call__(self, x): # Return an initialized instance of MLP by only calling `setup`. -rngkey = jax.random.PRNGKey(10) +rngkey = jax.random.key(10) init_variables = MLP().init({'params': rngkey}, jnp.ones((1, 3))) pprint(init_variables) diff --git a/examples/linen_design_test/mlp_inline.py b/examples/linen_design_test/mlp_inline.py index 73d525acff..b631d19d83 100644 --- a/examples/linen_design_test/mlp_inline.py +++ b/examples/linen_design_test/mlp_inline.py @@ -39,7 +39,7 @@ def __call__(self, x): # initializing all variables. # # Variable shapes depend on the input shape passed in. 
-rngkey = jax.random.PRNGKey(10) +rngkey = jax.random.key(10) model = MLP((2, 1)) x = jnp.ones((1, 3)) mlp_variables = model.init(rngkey, x) diff --git a/examples/linen_design_test/mlp_lazy.py b/examples/linen_design_test/mlp_lazy.py index 7283d483e4..7e246917bf 100644 --- a/examples/linen_design_test/mlp_lazy.py +++ b/examples/linen_design_test/mlp_lazy.py @@ -41,7 +41,7 @@ def __call__(self, x): # initializing all variables. # # Variable shapes depend on the input shape passed in. -rngkey = jax.random.PRNGKey(10) +rngkey = jax.random.key(10) mlp_variables = MLP().init(rngkey, jnp.zeros((1, 3))) pprint(mlp_variables) diff --git a/examples/linen_design_test/tied_autoencoder.py b/examples/linen_design_test/tied_autoencoder.py index 4f49d5647e..a7dadeef19 100644 --- a/examples/linen_design_test/tied_autoencoder.py +++ b/examples/linen_design_test/tied_autoencoder.py @@ -39,7 +39,7 @@ # tae = TiedAutoEncoder(parent=None) # tae = tae.initialized( -# {'params': random.PRNGKey(42)}, +# {'params': random.key(42)}, # jnp.ones((1, 16))) # print("reconstruct", jnp.shape(tae(jnp.ones((1, 16))))) # print("var shapes", jax.tree_util.tree_map(jnp.shape, tae.variables)) diff --git a/examples/linen_design_test/weight_std.py b/examples/linen_design_test/weight_std.py index 24b90c4fe1..c384a00b91 100644 --- a/examples/linen_design_test/weight_std.py +++ b/examples/linen_design_test/weight_std.py @@ -58,5 +58,5 @@ def standardize(x, axis, eps=1e-8): # std_module = StdWeight(module) # return std_module(x) -# m_variables = MyModule().init({'params': jax.random.PRNGKey(10)}, jnp.ones((1, 4))) +# m_variables = MyModule().init({'params': jax.random.key(10)}, jnp.ones((1, 4))) # print(m_variables) diff --git a/examples/lm1b/temperature_sampler_test.py b/examples/lm1b/temperature_sampler_test.py index a0c7f46ece..6ccb079c7e 100644 --- a/examples/lm1b/temperature_sampler_test.py +++ b/examples/lm1b/temperature_sampler_test.py @@ -28,7 +28,7 @@ class TestTemperatureSampler(absltest.TestCase): def test_temperature_sampler(self): tokens = jnp.array([[5, 0, 0, 0]], dtype=jnp.int32) cache = None - key = jax.random.PRNGKey(0) + key = jax.random.key(0) def tokens_to_logits(tokens, cache): jax.debug.print('tokens: {}', tokens) diff --git a/examples/lm1b/train.py b/examples/lm1b/train.py index 5284a8072b..ffd8060a92 100644 --- a/examples/lm1b/train.py +++ b/examples/lm1b/train.py @@ -235,7 +235,7 @@ def predict_step( """Predict language model on a batch.""" target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] initial_variables = models.TransformerLM(config).init( - jax.random.PRNGKey(0), jnp.ones(target_shape, config.dtype) + jax.random.key(0), jnp.ones(target_shape, config.dtype) ) cache = initial_variables["cache"] @@ -437,7 +437,7 @@ def encode_strings(strs, max_len): predict_config = train_config.replace(deterministic=True, decode=True) start_step = 0 - rng = jax.random.PRNGKey(config.seed) + rng = jax.random.key(config.seed) rng, init_rng = jax.random.split(rng) rng, inference_rng = random.split(rng) input_shape = (config.per_device_batch_size, config.max_target_length) diff --git a/examples/mnist/train.py b/examples/mnist/train.py index 32a6db4448..f9fc4582b6 100644 --- a/examples/mnist/train.py +++ b/examples/mnist/train.py @@ -128,7 +128,7 @@ def train_and_evaluate( The train state (which includes the `.params`). 
""" train_ds, test_ds = get_datasets() - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) summary_writer = tensorboard.SummaryWriter(workdir) summary_writer.hparams(dict(config)) diff --git a/examples/mnist/train_test.py b/examples/mnist/train_test.py index bf234b0fbe..a0146b98c2 100644 --- a/examples/mnist/train_test.py +++ b/examples/mnist/train_test.py @@ -41,7 +41,7 @@ def setUp(self): def test_cnn(self): """Tests CNN module used as the trainable model.""" - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) inputs = jnp.ones((1, 28, 28, 3), jnp.float32) output, variables = train.CNN().init_with_output(rng, inputs) diff --git a/examples/nlp_seq/train.py b/examples/nlp_seq/train.py index 72ce469507..c32bc0a48d 100644 --- a/examples/nlp_seq/train.py +++ b/examples/nlp_seq/train.py @@ -304,7 +304,7 @@ def main(argv): model = models.Transformer(config) - rng = random.PRNGKey(random_seed) + rng = random.key(random_seed) rng, init_rng = random.split(rng) # call a jitted initialization function to get the initial parameter tree diff --git a/examples/ogbg_molpcba/models_test.py b/examples/ogbg_molpcba/models_test.py index f0dd4a536f..4c4cc30176 100644 --- a/examples/ogbg_molpcba/models_test.py +++ b/examples/ogbg_molpcba/models_test.py @@ -28,8 +28,8 @@ class ModelsTest(parameterized.TestCase): def setUp(self): super().setUp() self.rngs = { - 'params': jax.random.PRNGKey(0), - 'dropout': jax.random.PRNGKey(1), + 'params': jax.random.key(0), + 'dropout': jax.random.key(1), } n_node = jnp.arange(3, 11) n_edge = jnp.arange(4, 12) diff --git a/examples/ogbg_molpcba/train.py b/examples/ogbg_molpcba/train.py index 59d1346f2b..8f0229d0e7 100644 --- a/examples/ogbg_molpcba/train.py +++ b/examples/ogbg_molpcba/train.py @@ -319,7 +319,7 @@ def train_and_evaluate( # Create and initialize the network. logging.info('Initializing network.') - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) rng, init_rng = jax.random.split(rng) init_graphs = next(datasets['train'].as_numpy_iterator()) init_graphs = replace_globals(init_graphs) diff --git a/examples/ogbg_molpcba/train_test.py b/examples/ogbg_molpcba/train_test.py index 268c7ce165..27aeb4df2c 100644 --- a/examples/ogbg_molpcba/train_test.py +++ b/examples/ogbg_molpcba/train_test.py @@ -143,7 +143,7 @@ def setUp(self): print('Running on platform:', platform.upper()) # Create PRNG keys. - self.rng = jax.random.PRNGKey(0) + self.rng = jax.random.key(0) # Create dummy datasets. 
self.datasets = get_dummy_datasets(dataset_length=20, batch_size=10) diff --git a/examples/ppo/ppo_lib.py b/examples/ppo/ppo_lib.py index 2b7258acf9..eac6b57e35 100644 --- a/examples/ppo/ppo_lib.py +++ b/examples/ppo/ppo_lib.py @@ -325,7 +325,7 @@ def train( config.num_agents * config.actor_steps // config.batch_size ) - initial_params = get_initial_params(jax.random.PRNGKey(0), model) + initial_params = get_initial_params(jax.random.key(0), model) state = create_train_state( initial_params, model, diff --git a/examples/ppo/ppo_lib_test.py b/examples/ppo/ppo_lib_test.py index bf944fbaf8..ac09e4185c 100644 --- a/examples/ppo/ppo_lib_test.py +++ b/examples/ppo/ppo_lib_test.py @@ -105,7 +105,7 @@ def choose_random_outputs(self): def test_model(self): outputs = self.choose_random_outputs() module = models.ActorCritic(num_outputs=outputs) - params = ppo_lib.get_initial_params(jax.random.PRNGKey(0), module) + params = ppo_lib.get_initial_params(jax.random.key(0), module) test_batch_size, obs_shape = 10, (84, 84, 4) random_input = np.random.random(size=(test_batch_size,) + obs_shape) log_probs, values = agent.policy_action(module.apply, params, random_input) @@ -138,7 +138,7 @@ def test_optimization_step(self): entropy_coeff = 0.01 batch_size = 256 module = models.ActorCritic(num_outputs) - initial_params = ppo_lib.get_initial_params(jax.random.PRNGKey(0), module) + initial_params = ppo_lib.get_initial_params(jax.random.key(0), module) config = ml_collections.ConfigDict({ 'learning_rate': 2.5e-4, 'decaying_lr_and_clip_param': True, diff --git a/examples/seq2seq/seq2seq.ipynb b/examples/seq2seq/seq2seq.ipynb index c9ca74edc4..6ecce14748 100644 --- a/examples/seq2seq/seq2seq.ipynb +++ b/examples/seq2seq/seq2seq.ipynb @@ -696,7 +696,7 @@ "outputs": [], "source": [ "# Using different random seeds generates different samples.\n", - "preds = train.decode(state.params, inputs, jax.random.PRNGKey(0), ctable)" + "preds = train.decode(state.params, inputs, jax.random.key(0), ctable)" ] }, { diff --git a/examples/seq2seq/train.py b/examples/seq2seq/train.py index 6db75b490c..55141a43e8 100644 --- a/examples/seq2seq/train.py +++ b/examples/seq2seq/train.py @@ -206,7 +206,7 @@ def train_and_evaluate(workdir: str) -> train_state.TrainState: # TODO(marcvanzee): Integrate ctable with train_state. 
ctable = CTable('0123456789+= ', FLAGS.max_len_query_digit) - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) state = get_train_state(rng, ctable) writer = metric_writers.create_default_writer(workdir) diff --git a/examples/seq2seq/train_test.py b/examples/seq2seq/train_test.py index a932810294..564b1b61b5 100644 --- a/examples/seq2seq/train_test.py +++ b/examples/seq2seq/train_test.py @@ -40,7 +40,7 @@ def create_train_state(ctable): hidden_size=train.FLAGS.hidden_size, vocab_size=ctable.vocab_size, ) - params = train.get_initial_params(model, jax.random.PRNGKey(0), ctable) + params = train.get_initial_params(model, jax.random.key(0), ctable) tx = optax.adam(train.FLAGS.learning_rate) state = train_state.TrainState.create( apply_fn=model.apply, params=params, tx=tx @@ -89,7 +89,7 @@ def test_train_one_step(self): batch = ctable.get_batch(128) state = create_train_state(ctable) - key = random.PRNGKey(0) + key = random.key(0) _, train_metrics = train.train_step(state, batch, key, ctable.eos_id) self.assertLessEqual(train_metrics['loss'], 5) @@ -98,7 +98,7 @@ def test_train_one_step(self): def test_decode_batch(self): ctable = create_ctable() batch = ctable.get_batch(5) - key = random.PRNGKey(0) + key = random.key(0) state = create_train_state(ctable) train.decode_batch(state, batch, key, ctable) diff --git a/examples/sst2/models.py b/examples/sst2/models.py index f5a4519970..ee92c7a6f2 100644 --- a/examples/sst2/models.py +++ b/examples/sst2/models.py @@ -182,7 +182,7 @@ def __call__(self, carry, x): def initialize_carry(self, input_shape): # Use fixed random key since default state init fn is just zeros. return nn.OptimizedLSTMCell(self.hidden_size, parent=None).initialize_carry( - jax.random.PRNGKey(0), input_shape + jax.random.key(0), input_shape ) diff --git a/examples/sst2/models_test.py b/examples/sst2/models_test.py index e97aa18737..c1a42a0c02 100644 --- a/examples/sst2/models_test.py +++ b/examples/sst2/models_test.py @@ -35,7 +35,7 @@ def test_embedder_returns_correct_output_shape(self): model = models.Embedder( vocab_size=vocab_size, embedding_size=embedding_size ) - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) token_ids = np.array([[2, 4, 3], [2, 6, 3]], dtype=np.int32) output, _ = model.init_with_output(rng, token_ids, deterministic=True) self.assertEqual((token_ids.shape) + (embedding_size,), output.shape) @@ -47,7 +47,7 @@ def test_lstm_returns_correct_output_shape(self): embedding_size = 4 hidden_size = 5 model = models.SimpleLSTM(5) - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) inputs = np.random.RandomState(0).normal( size=[batch_size, seq_len, embedding_size] ) @@ -62,7 +62,7 @@ def test_bilstm_returns_correct_output_shape(self): embedding_size = 4 hidden_size = 5 model = models.SimpleBiLSTM(hidden_size=hidden_size) - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) inputs = np.random.RandomState(0).normal( size=[batch_size, seq_len, embedding_size] ) @@ -92,7 +92,7 @@ def test_text_classifier_returns_correct_output_shape(self): deterministic=True, ) - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) token_ids = np.array([[2, 4, 3], [2, 6, 3]], dtype=np.int32) lengths = np.array([2, 3], dtype=np.int32) output, _ = model.init_with_output(rng, token_ids, lengths) diff --git a/examples/sst2/train.py b/examples/sst2/train.py index ab6514c62f..543df0d125 100644 --- a/examples/sst2/train.py +++ b/examples/sst2/train.py @@ -264,7 +264,7 @@ def train_and_evaluate( eval_step_fn = jax.jit(eval_step) # Create model and a state that contains the 
parameters. - rng = jax.random.PRNGKey(config.seed) + rng = jax.random.key(config.seed) model = model_from_config(config) state = create_train_state(rng, config, model) diff --git a/examples/sst2/train_test.py b/examples/sst2/train_test.py index bcd7107cf1..3151771033 100644 --- a/examples/sst2/train_test.py +++ b/examples/sst2/train_test.py @@ -34,7 +34,7 @@ def test_train_step_updates_parameters(self): # Create model and a state that contains the parameters. config = default_config.get_config() config.vocab_size = 13 - rng = jax.random.PRNGKey(config.seed) + rng = jax.random.key(config.seed) model = train.model_from_config(config) state = train.create_train_state(rng, config, model) diff --git a/examples/vae/train.py b/examples/vae/train.py index 273f860115..68f192fc5b 100644 --- a/examples/vae/train.py +++ b/examples/vae/train.py @@ -93,7 +93,7 @@ def eval_model(vae): def train_and_evaluate(config: ml_collections.ConfigDict): """Train and evaulate pipeline.""" - rng = random.PRNGKey(0) + rng = random.key(0) rng, key = random.split(rng) ds_builder = tfds.builder('binarized_mnist') diff --git a/examples/wmt/train.py b/examples/wmt/train.py index 929cdf9c93..f86c7fbcf9 100644 --- a/examples/wmt/train.py +++ b/examples/wmt/train.py @@ -276,7 +276,7 @@ def initialize_cache(inputs, max_decode_len, config): """Initialize a cache for a given input shape and max decode length.""" target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:] initial_variables = models.Transformer(config).init( - jax.random.PRNGKey(0), + jax.random.key(0), jnp.ones(inputs.shape, config.dtype), jnp.ones(target_shape, config.dtype), ) @@ -523,7 +523,7 @@ def decode_tokens(toks): predict_config = train_config.replace(deterministic=True, decode=True) start_step = 0 - rng = jax.random.PRNGKey(config.seed) + rng = jax.random.key(config.seed) rng, init_rng = jax.random.split(rng) input_shape = (config.per_device_batch_size, config.max_target_length) target_shape = (config.per_device_batch_size, config.max_target_length) diff --git a/flax/core/flax_functional_engine.ipynb b/flax/core/flax_functional_engine.ipynb index 18b6c70f34..ce83035cff 100644 --- a/flax/core/flax_functional_engine.ipynb +++ b/flax/core/flax_functional_engine.ipynb @@ -90,7 +90,7 @@ "model_fn = functools.partial(dense, features=3)\n", "\n", "x = jnp.ones((1, 2))\n", - "y, params = init(model_fn)(random.PRNGKey(0), x)\n", + "y, params = init(model_fn)(random.key(0), x)\n", "print(params)\n", "\n", "def mlp(scope: Scope, inputs: Array, features: int):\n", @@ -98,7 +98,7 @@ " hidden = nn.relu(hidden)\n", " return dense(scope.push('out'), hidden, 1)\n", "\n", - "init(mlp)(random.PRNGKey(0), x, features=3)" + "init(mlp)(random.key(0), x, features=3)" ] }, { @@ -144,7 +144,7 @@ " table = scope.param('table', init_fn, (num_embeddings, features))\n", " return Embedding(table)\n", "\n", - "embedding, _ = init(embedding)(random.PRNGKey(0), num_embeddings=2, features=3)\n", + "embedding, _ = init(embedding)(random.key(0), num_embeddings=2, features=3)\n", "print(embedding.table)\n", "print(embedding.lookup(1))\n", "print(embedding.attend(jnp.ones((1, 3,))))" @@ -236,7 +236,7 @@ "\n", "x = jnp.ones((1, 2))\n", "carry = lstm_init_carry((1,), 3)\n", - "y, variables = init(lstm)(random.PRNGKey(0), carry, x)\n", + "y, variables = init(lstm)(random.key(0), carry, x)\n", "jax.tree_util.tree_map(np.shape, (y, variables))" ] }, @@ -269,7 +269,7 @@ " lstm_scan = lift.scan(lstm, in_axes=1, out_axes=1, variable_broadcast='params', split_rngs={'params': False})\n", " 
return lstm_scan(scope, init_carry, xs)\n", "\n", - "key1, key2 = random.split(random.PRNGKey(0), 2)\n", + "key1, key2 = random.split(random.key(0), 2)\n", "xs = random.uniform(key1, (1, 5, 2))\n", "\n", "\n", diff --git a/flax/core/meta.py b/flax/core/meta.py index 5ad9005ead..b21f2a2ee6 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -209,10 +209,10 @@ def __call__(self, x): # this way we can determinte the PartitionSpecs for the init variables # before we call the init fn. var_spec = nn.get_partition_spec( - jax.eval_shape(mlp.init, random.PRNGKey(0), x)) + jax.eval_shape(mlp.init, random.key(0), x)) init_fn = mesh(pjit(mlp.init, (None, PartitionSpec("data", "model")), var_spec)) - variables = init_fn(random.PRNGKey(0), x) + variables = init_fn(random.key(0), x) apply_fn = mesh(pjit( mlp.apply, (var_spec, PartitionSpec("data", "model")), diff --git a/flax/core/scope.py b/flax/core/scope.py index 0c6b6441ad..aa00f5eb24 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -1143,7 +1143,7 @@ def f(scope, x): k = scope.param("kernel", nn.initializers.lecun_normal(), (x.shape[-1], x.shape[-1])) return x @ k init_fn = lazy_init(f) - variables = init_fn(random.PRNGKey(0), jax.ShapeDtypeStruct((1, 128), jnp.float32)) + variables = init_fn(random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32)) Args: diff --git a/flax/cursor.py b/flax/cursor.py index 49561f1bb0..f419403fa9 100644 --- a/flax/cursor.py +++ b/flax/cursor.py @@ -296,7 +296,7 @@ def __call__(self, x): x = nn.relu(x) return x - params = Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'] + params = Model().init(jax.random.key(0), jnp.empty((1, 2)))['params'] def update_fn(path, value): '''Multiply all dense kernel params by 2 and add 1. @@ -320,7 +320,7 @@ def update_fn(path, value): jax.tree_util.tree_map( lambda x, y: (x == y).all(), params, - Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[ + Model().init(jax.random.key(0), jnp.empty((1, 2)))[ 'params' ], ) diff --git a/flax/errors.py b/flax/errors.py index df28287728..0c82064dbb 100644 --- a/flax/errors.py +++ b/flax/errors.py @@ -83,7 +83,7 @@ def __call__(self, x): # this causes an error when using lazy_init. 
k = self.param("kernel", lambda _: x) return x * k - Foo().lazy_init(random.PRNGKey(0), jax.ShapeDtypeStruct((8, 4), jnp.float32)) + Foo().lazy_init(random.key(0), jax.ShapeDtypeStruct((8, 4), jnp.float32)) """ def __init__(self, partial_val): @@ -131,7 +131,7 @@ def __call__(self, x): So, ``Foo`` is initialized as follows:: - init_rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)} + init_rngs = {'params': random.key(0), 'dropout': random.key(1)} variables = Foo().init(init_rngs, init_inputs) If a Module only requires an rng for ``params``, you can use:: @@ -144,7 +144,7 @@ def __call__(self, x): When applying ``Foo``, only the rng for ``dropout`` is needed, because ``params`` is only used for initializing the Module parameters:: - Foo().apply(variables, inputs, rngs={'dropout': random.PRNGKey(2)}) + Foo().apply(variables, inputs, rngs={'dropout': random.key(2)}) If a Module only requires an rng for ``params``, you don't have to provide rngs for apply at all:: @@ -209,7 +209,7 @@ def __call__(self, inputs, embed_name='embedding'): return embedding[inputs] model = Embed(4, 8) - variables = model.init(random.PRNGKey(0), jnp.ones((5, 5, 1))) + variables = model.init(random.key(0), jnp.ones((5, 5, 1))) _ = model.apply(variables, jnp.ones((5, 5, 1)), 'embed') """ @@ -264,7 +264,7 @@ def __call__(self, x): (((x.ndim - 1,), (0,)), ((), ()))) return y - variables = NoBiasDense().init(random.PRNGKey(0), jnp.ones((5, 5, 1))) + variables = NoBiasDense().init(random.key(0), jnp.ones((5, 5, 1))) _ = NoBiasDense().apply(variables, jnp.ones((5, 5))) """ @@ -450,7 +450,7 @@ def setup(self): def __call__(self, x): conv = nn.Conv(features=3, kernel_size=3) - Foo().init(random.PRNGKey(0), jnp.zeros((1,))) + Foo().init(random.key(0), jnp.zeros((1,))) Note that this error is also thrown if you partially defined a Module inside setup:: @@ -463,7 +463,7 @@ def __call__(self, x): x = self.conv(kernel_size=4)(x) return x - Foo().init(random.PRNGKey(0), jnp.zeros((1,))) + Foo().init(random.key(0), jnp.zeros((1,))) In this case, ``self.conv(kernel_size=4)`` is called from ``__call__``, which is disallowed because it's neither within ``setup`` nor a method wrapped in @@ -491,7 +491,7 @@ def setup(self): def __call__(self, x): return nn.Dense(self.features)(x) - variables = SomeModule().init(random.PRNGKey(0), jnp.ones((1, ))) + variables = SomeModule().init(random.key(0), jnp.ones((1, ))) Instead, these attributes should be set when initializing the Module:: @@ -502,7 +502,7 @@ class Foo(nn.Module): def __call__(self, x): return nn.Dense(self.features)(x) - variables = SomeModule(features=3).init(random.PRNGKey(0), jnp.ones((1, ))) + variables = SomeModule(features=3).init(random.key(0), jnp.ones((1, ))) TODO(marcvanzee): Link to a design note explaining why it's necessary for modules to stay frozen (otherwise we can't safely clone them, which we use for @@ -529,7 +529,7 @@ def __call__(self, x, num_features=10): x = nn.Dense(self.num_features)(x) return x - s = SomeModule().init(random.PRNGKey(0), jnp.ones((5, 5))) + s = SomeModule().init(random.key(0), jnp.ones((5, 5))) Similarly, the error is raised when trying to modify a submodule's attributes after constructing it, even if this is done in the ``setup()`` method of the @@ -619,7 +619,7 @@ class CallCompactUnboundModuleError(FlaxError): :meth:`Module.init() ` to get initial variables):: from jax import random - variables = test_dense.init(random.PRNGKey(0), jnp.ones((5,5))) + variables = test_dense.init(random.key(0), jnp.ones((5,5))) y = 
test_dense.apply(variables, jnp.ones((5,5))) """ @@ -699,8 +699,8 @@ class B(nn.Module): def __call__(self, x): return x - k = random.PRNGKey(0) - x = random.uniform(random.PRNGKey(1), (2,)) + k = random.key(0) + x = random.uniform(random.key(1), (2,)) B.init(k, x) # B is module class, not B() a module instance B.apply(vs, x) # similar issue with apply called on class instead of instance. @@ -731,7 +731,7 @@ def __call__(self, input): return input + 3 r = A(x=3) - r.init(jax.random.PRNGKey(2), jnp.ones(3)) + r.init(jax.random.key(2), jnp.ones(3)) """ def __init__(self): @@ -753,7 +753,7 @@ def __call__(self, x): return self.prop foo = Foo() - variables = foo.init(jax.random.PRNGKey(0), jnp.ones(shape=(1, 8))) + variables = foo.init(jax.random.key(0), jnp.ones(shape=(1, 8))) """ def __init__(self): diff --git a/flax/linen/initializers.py b/flax/linen/initializers.py index 8579a9c62e..1e4977e380 100644 --- a/flax/linen/initializers.py +++ b/flax/linen/initializers.py @@ -44,7 +44,7 @@ def zeros_init() -> Initializer: >>> import jax, jax.numpy as jnp >>> from flax.linen.initializers import zeros_init >>> zeros_initializer = zeros_init() - >>> zeros_initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) + >>> zeros_initializer(jax.random.key(42), (2, 3), jnp.float32) Array([[0., 0., 0.], [0., 0., 0.]], dtype=float32) """ @@ -57,7 +57,7 @@ def ones_init() -> Initializer: >>> import jax, jax.numpy as jnp >>> from flax.linen.initializers import ones_init >>> ones_initializer = ones_init() - >>> ones_initializer(jax.random.PRNGKey(42), (3, 2), jnp.float32) + >>> ones_initializer(jax.random.key(42), (3, 2), jnp.float32) Array([[1., 1.], [1., 1.], [1., 1.]], dtype=float32) diff --git a/flax/linen/module.py b/flax/linen/module.py index 2f0b7760a5..3d45e30da2 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -89,7 +89,7 @@ TestScope = type( 'TestScope', (Scope,), - {'make_rng': lambda self, name: jax.random.PRNGKey(0)}, + {'make_rng': lambda self, name: jax.random.key(0)}, ) @@ -1767,7 +1767,7 @@ def __call__(self, x): x = jnp.ones((16, 9)) ae = AutoEncoder() - variables = ae.init(jax.random.PRNGKey(0), x) + variables = ae.init(jax.random.key(0), x) model = ae.bind(variables) z = model.encoder(x) x_reconstructed = model.decoder(z) @@ -1811,7 +1811,7 @@ def __call__(self, x): return self.decoder(self.encoder(x)) module = AutoEncoder() - variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 784))) + variables = module.init(jax.random.key(0), jnp.ones((1, 784))) ... # Extract the Encoder sub-Module and its variables encoder, encoder_vars = module.bind(variables).encoder.unbind() @@ -2027,7 +2027,7 @@ def init( ... return nn.Dense(1)(x) ... >>> module = Foo() - >>> key = jax.random.PRNGKey(0) + >>> key = jax.random.key(0) >>> variables = module.init(key, jnp.empty((1, 7)), train=False) If you pass a single ``PRNGKey``, Flax will use it to feed the ``'params'`` @@ -2051,8 +2051,8 @@ def init( ... return nn.Dense(1)(x) ... >>> module = Foo() - >>> rngs = {'params': jax.random.PRNGKey(0), - ... 'noise': jax.random.PRNGKey(1)} + >>> rngs = {'params': jax.random.key(0), + ... 
'noise': jax.random.key(1)} >>> variables = module.init(rngs, jnp.empty((1, 7)), train=False) Jitting `init` initializes a model lazily using only the shapes of the @@ -2061,7 +2061,7 @@ def init( >>> module = nn.Dense(1) >>> init_jit = jax.jit(module.init) - >>> variables = init_jit(jax.random.PRNGKey(0), jnp.empty((1, 7))) + >>> variables = init_jit(jax.random.key(0), jnp.empty((1, 7))) ``init`` is a light wrapper over ``apply``, so other ``apply`` arguments like ``method``, ``mutable``, and ``capture_intermediates`` are also @@ -2234,7 +2234,7 @@ def __call__(self, x): x = jnp.ones((16, 9)) model = Foo() - variables = model.init(jax.random.PRNGKey(0), x) + variables = model.init(jax.random.key(0), x) y, state = model.apply(variables, x, mutable=['intermediates']) print(state['intermediates']) # {'h': (...,)} @@ -2255,7 +2255,7 @@ def __call__(self, x): return x model = Foo2() - variables = model.init(jax.random.PRNGKey(0), x) + variables = model.init(jax.random.key(0), x) y, state = model.apply( variables, jnp.ones((1, 1)), mutable=['intermediates']) print(state['intermediates']) # ==> {'h': [[3.]]} @@ -2324,7 +2324,7 @@ def loss(params, perturbations, inputs, targets): x = jnp.ones((2, 9)) y = jnp.ones((2, 2)) model = Foo() - variables = model.init(jax.random.PRNGKey(0), x) + variables = model.init(jax.random.key(0), x) intm_grads = jax.grad(loss, argnums=1)( variables['params'], variables['perturbations'], x, y) print(intm_grads['dense3']) # ==> [[-1.456924 -0.44332537 0.02422847] @@ -2389,7 +2389,7 @@ def __call__(self, x): x = jnp.ones((16, 9)) - print(Foo().tabulate(jax.random.PRNGKey(0), x)) + print(Foo().tabulate(jax.random.key(0), x)) This gives the following output:: diff --git a/flax/linen/normalization.py b/flax/linen/normalization.py index 2be694634a..81c1cf595e 100644 --- a/flax/linen/normalization.py +++ b/flax/linen/normalization.py @@ -425,9 +425,9 @@ class RMSNorm(Module): >>> import jax >>> import flax.linen as nn ... - >>> x = jax.random.uniform(jax.random.PRNGKey(0), (2, 3)) + >>> x = jax.random.uniform(jax.random.key(0), (2, 3)) >>> layer = nn.RMSNorm() - >>> variables = layer.init(jax.random.PRNGKey(1), x) + >>> variables = layer.init(jax.random.key(1), x) >>> y = layer.apply(variables, x) Attributes: diff --git a/flax/linen/recurrent.py b/flax/linen/recurrent.py index 259127380b..593546f0a8 100644 --- a/flax/linen/recurrent.py +++ b/flax/linen/recurrent.py @@ -632,7 +632,7 @@ class RNN(Module): ... 
>>> x = jnp.ones((10, 50, 32)) # (batch, time, features) >>> lstm = nn.RNN(nn.LSTMCell(64)) - >>> variables = lstm.init(jax.random.PRNGKey(0), x) + >>> variables = lstm.init(jax.random.key(0), x) >>> y = lstm.apply(variables, x) >>> y.shape # (batch, time, cell_size) (10, 50, 64) @@ -645,7 +645,7 @@ class RNN(Module): >>> x = jnp.ones((10, 50, 32, 32, 3)) # (batch, time, height, width, features) >>> conv_lstm = nn.RNN(nn.ConvLSTMCell(64, kernel_size=(3, 3))) - >>> y, variables = conv_lstm.init_with_output(jax.random.PRNGKey(0), x) + >>> y, variables = conv_lstm.init_with_output(jax.random.key(0), x) >>> y.shape # (batch, time, height, width, features) (10, 50, 32, 32, 64) @@ -655,7 +655,7 @@ class RNN(Module): >>> x = jnp.ones((50, 10, 32)) # (time, batch, features) >>> lstm = nn.RNN(nn.LSTMCell(64), time_major=True) - >>> variables = lstm.init(jax.random.PRNGKey(0), x) + >>> variables = lstm.init(jax.random.key(0), x) >>> y = lstm.apply(variables, x) >>> y.shape # (time, batch, cell_size) (50, 10, 64) @@ -665,7 +665,7 @@ class RNN(Module): >>> x = jnp.ones((10, 50, 32)) # (batch, time, features) >>> lstm = nn.RNN(nn.LSTMCell(64), return_carry=True) - >>> variables = lstm.init(jax.random.PRNGKey(0), x) + >>> variables = lstm.init(jax.random.key(0), x) >>> carry, y = lstm.apply(variables, x) >>> jax.tree_map(jnp.shape, carry) # ((batch, cell_size), (batch, cell_size)) ((10, 64), (10, 64)) @@ -768,7 +768,7 @@ def __call__( initial_carry: the initial carry, if not provided it will be initialized using the cell's :meth:`RNNCellBase.initialize_carry` method. init_key: a PRNG key used to initialize the carry, if not provided - ``jax.random.PRNGKey(0)`` will be used. Most cells will ignore this + ``jax.random.key(0)`` will be used. Most cells will ignore this argument. seq_lengths: an optional integer array of shape ``(*batch)`` indicating the length of each sequence, elements whose index in the time dimension @@ -828,7 +828,7 @@ def __call__( carry: Carry if initial_carry is None: if init_key is None: - init_key = random.PRNGKey(0) + init_key = random.key(0) input_shape = inputs.shape[:time_axis] + inputs.shape[time_axis + 1 :] carry = self.cell.initialize_carry(init_key, input_shape) diff --git a/flax/linen/summary.py b/flax/linen/summary.py index 67c195c337..9ebaf4256a 100644 --- a/flax/linen/summary.py +++ b/flax/linen/summary.py @@ -188,7 +188,7 @@ def __call__(self, x): return nn.Dense(2)(h) x = jnp.ones((16, 9)) - tabulate_fn = nn.tabulate(Foo(), jax.random.PRNGKey(0)) + tabulate_fn = nn.tabulate(Foo(), jax.random.key(0)) print(tabulate_fn(x)) diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 7f4cf1c775..adaf7822fc 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -673,7 +673,7 @@ def checkpoint( ... return x ... >>> model = CheckpointedMLP() - >>> variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 16))) + >>> variables = model.init(jax.random.key(0), jnp.ones((1, 16))) This function is aliased to ``remat`` just like ``jax.remat``. @@ -855,13 +855,13 @@ def scan( ... ... lstm = ScanLSTM(self.features) ... input_shape = x[:, 0].shape - ... carry = lstm.initialize_carry(jax.random.PRNGKey(0), input_shape) + ... carry = lstm.initialize_carry(jax.random.key(0), input_shape) ... carry, x = lstm(carry, x) ... return x ... 
>>> x = jnp.ones((4, 12, 7)) >>> module = LSTM(features=32) - >>> y, variables = module.init_with_output(jax.random.PRNGKey(0), x) + >>> y, variables = module.init_with_output(jax.random.key(0), x) Note that when providing a function to ``nn.scan``, the scanning happens over all arguments starting from the third argument, as specified by ``in_axes``. @@ -883,12 +883,12 @@ def scan( ... ... input_shape = x[:, 0].shape ... carry = cell.initialize_carry( - ... jax.random.PRNGKey(0), input_shape) + ... jax.random.key(0), input_shape) ... carry, x = scan(cell, carry, x) ... return x ... >>> module = LSTM(features=32) - >>> variables = module.init(jax.random.PRNGKey(0), jnp.ones((4, 12, 7))) + >>> variables = module.init(jax.random.key(0), jnp.ones((4, 12, 7))) You can also use ``scan`` to reduce the compilation time of your JAX program by merging multiple layers into a single scan loop, you can do this when @@ -915,7 +915,7 @@ def scan( ... return x ... >>> model = ResidualMLP(n_layers=4) - >>> variables = model.init(jax.random.PRNGKey(42), jnp.ones((1, 2))) + >>> variables = model.init(jax.random.key(42), jnp.ones((1, 2))) To reduce both compilation and memory usage, you can use :func:`remat_scan` which will in addition checkpoint each layer in the scan loop. @@ -1016,7 +1016,7 @@ def map_variables( ... return self.dense(x) ... >>> module = CausalDense(features=5) - >>> variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 5))) + >>> variables = module.init(jax.random.key(0), jnp.ones((1, 5))) Args: target: the module or function to be transformed. @@ -1262,7 +1262,7 @@ def body_fn(mdl, c): return nn.while_loop(cond_fn, body_fn, self, c, carry_variables='state') - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((2, 2)) intial_vars = WhileLoopExample().init(k, x) result, state = WhileLoopExample().apply(intial_vars, x, mutable=['state']) @@ -1516,7 +1516,7 @@ def bwd(vjp_fn, y_t): return sign_grad(nn.Dense(1), x).reshape(()) x = jnp.ones((2,)) - variables = Foo().init(random.PRNGKey(0), x) + variables = Foo().init(random.key(0), x) grad = jax.grad(Foo().apply)(variables, x) Args: diff --git a/pyproject.toml b/pyproject.toml index b299351035..f55f23d343 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ authors = [ ] dependencies = [ "numpy>=1.12", - "jax>=0.4.2", + "jax>=0.4.11", "msgpack", "optax", "orbax-checkpoint", diff --git a/tests/core/core_lift_test.py b/tests/core/core_lift_test.py index c433c2276e..74cc245c7f 100644 --- a/tests/core/core_lift_test.py +++ b/tests/core/core_lift_test.py @@ -39,7 +39,7 @@ def g(scopes, _): lift.vmap(g, variable_axes={}, split_rngs={})((scope, a), jnp.ones((1,))) - init(f)(random.PRNGKey(0)) + init(f)(random.key(0)) def test_undefined_param(self): def f(scope): @@ -70,8 +70,8 @@ def f(scope, x): return nn.dense(scope, x, 1) x = np.ones((3, 2)) - _, params = init(f)(random.PRNGKey(0), x) - init(f)(random.PRNGKey(0), x) + _, params = init(f)(random.key(0), x) + init(f)(random.key(0), x) self.assertEqual(compiles, 1) apply(f)(params, x) self.assertEqual(compiles, 2) # apply should cause a compile @@ -95,7 +95,7 @@ def f(scope, x, y): x = jnp.array([1.0, 2.0, 3.0]) y = jnp.array([4.0, 5.0, 6.0]) - _, params = init(f)(random.PRNGKey(0), x, y) + _, params = init(f)(random.key(0), x, y) params_grad, x_grad, y_grad = apply(f)(params, x, y) self.assertEqual( params_grad, @@ -122,13 +122,13 @@ def f(scope, x): return out_t x = jnp.ones((3,)) - _, params = init(f)(random.PRNGKey(0), x) + _, params = init(f)(random.key(0), x) y_t = 
apply(f)(params, x) np.testing.assert_allclose(y_t, jnp.ones_like(x)) def test_while_loop(self): def f(scope, x): - key_zero = random.PRNGKey(0) + key_zero = random.key(0) key_zero = jnp.broadcast_to(key_zero, (2, *key_zero.shape)) scope.param('inc', lambda _: 1) scope.put_variable('state', 'acc', 0) @@ -168,7 +168,7 @@ def body_fn(scope, c): x = 2 c, vars = apply(f, mutable=True)( - {}, x, rngs={'params': random.PRNGKey(1), 'loop': random.PRNGKey(2)} + {}, x, rngs={'params': random.key(1), 'loop': random.key(2)} ) self.assertEqual(vars['state']['acc'], x) self.assertEqual(c, 2 * x) @@ -197,7 +197,7 @@ def false_fn(scope, x): return lift.cond(pred, true_fn, false_fn, scope, x) x = jnp.ones((1, 3)) - y1, vars = init(f)(random.PRNGKey(0), x, True) + y1, vars = init(f)(random.key(0), x, True) self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 0}) y2, vars = apply(f, mutable='state')(vars, x, False) self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 1}) @@ -224,7 +224,7 @@ def c_fn(scope, x): return lift.switch(index, [a_fn, b_fn, c_fn], scope, x) x = jnp.ones((1, 3)) - y1, vars = init(f)(random.PRNGKey(0), x, 0) + y1, vars = init(f)(random.key(0), x, 0) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 0, 'c_count': 0}) y2, updates = apply(f, mutable='state')(vars, x, 1) vars = copy(vars, updates) @@ -252,7 +252,7 @@ def test(scope, x): self.assertEqual(val0, val1) return x - init(test)(random.PRNGKey(0), 1.0) + init(test)(random.key(0), 1.0) if __name__ == '__main__': diff --git a/tests/core/core_meta_test.py b/tests/core/core_meta_test.py index 40db471f8e..a1f3fac6ba 100644 --- a/tests/core/core_meta_test.py +++ b/tests/core/core_meta_test.py @@ -45,7 +45,7 @@ def g(scope, x): metadata_params={meta.PARTITION_NAME: 'batch'}, )(scope, xs) - _, variables = init(f)(random.PRNGKey(0), jnp.zeros((8, 3))) + _, variables = init(f)(random.key(0), jnp.zeros((8, 3))) self.assertEqual( variables['params']['kernel'].names, ('batch', 'in', 'out') ) @@ -78,7 +78,7 @@ def g(scope, x): metadata_params={meta.PARTITION_NAME: 'batch'}, )(scope, xs) - _, variables = init(f)(random.PRNGKey(0), jnp.zeros((8, 3))) + _, variables = init(f)(random.key(0), jnp.zeros((8, 3))) self.assertEqual( variables['params']['kernel'].names, ('batch', 'in', 'out') ) @@ -101,7 +101,7 @@ def g(scope, x): metadata_params={}, )(scope, xs) - init(f)(random.PRNGKey(0), jnp.zeros((8, 3))) + init(f)(random.key(0), jnp.zeros((8, 3))) def test_unbox(self): xs = { @@ -141,7 +141,7 @@ def body(scope, x): )(scope, x) return c - _, variables = init(f)(random.PRNGKey(0), jnp.zeros((8, 3))) + _, variables = init(f)(random.key(0), jnp.zeros((8, 3))) boxed_shapes = jax.tree_map(jnp.shape, variables['params']) self.assertEqual( boxed_shapes, @@ -209,7 +209,7 @@ def f(scope, x): @jax.jit def create_state(): - y, variables = init(f)(random.PRNGKey(0), jnp.zeros((8, 4))) + y, variables = init(f)(random.key(0), jnp.zeros((8, 4))) spec = meta.get_partition_spec(variables) shardings = jax.tree_map(lambda s: sharding.NamedSharding(mesh, s), spec) variables = jax.lax.with_sharding_constraint(variables, shardings) diff --git a/tests/core/core_scope_test.py b/tests/core/core_scope_test.py index cfd514bd02..86634a1c8a 100644 --- a/tests/core/core_scope_test.py +++ b/tests/core/core_scope_test.py @@ -37,10 +37,10 @@ def f(scope): self.assertFalse(scope.has_rng('dropout')) rng = scope.make_rng('params') self.assertTrue( - np.all(rng == LazyRng.create(random.PRNGKey(0), 1).as_jax_rng()) + np.all(rng == 
LazyRng.create(random.key(0), 1).as_jax_rng()) ) - init(f)(random.PRNGKey(0)) + init(f)(random.key(0)) def test_in_filter(self): filter_true = lambda x, y: self.assertTrue(scope.in_filter(x, y)) @@ -160,7 +160,7 @@ def f(scope): r' immutable.' ) with self.assertRaisesRegex(errors.ModifyScopeVariableError, msg): - init(f, mutable='params')(random.PRNGKey(0)) + init(f, mutable='params')(random.key(0)) def test_undefined_param(self): def f(scope): @@ -182,7 +182,7 @@ def test_rngs_check_w_frozen_dict(self): def f(scope, x): return x - _ = apply(f)({}, np.array([0.0]), rngs=freeze({'a': random.PRNGKey(0)})) + _ = apply(f)({}, np.array([0.0]), rngs=freeze({'a': random.key(0)})) def test_rng_check_w_old_and_new_keys(self): # random.key always returns a new-style typed PRNG key. @@ -210,10 +210,10 @@ def g(scope): scope.child(g)() - jax.jit(init(f))(random.PRNGKey(0)) + jax.jit(init(f))(random.key(0)) def test_rng_counter_reuse(self): - root = Scope({}, {'dropout': random.PRNGKey(0)}) + root = Scope({}, {'dropout': random.key(0)}) def f(scope): return scope.make_rng('dropout') @@ -267,7 +267,7 @@ def f(scope, x): init_fn = lazy_init(f) # provide a massive input message which would OOM if any compute ops were actually executed variables = init_fn( - random.PRNGKey(0), + random.key(0), jax.ShapeDtypeStruct((1024 * 1024 * 1024, 128), jnp.float32), ) self.assertEqual(variables['params']['kernel'].shape, (128, 128)) @@ -280,12 +280,12 @@ def f(scope, x): init_fn = lazy_init(f) with self.assertRaises(errors.LazyInitError): - init_fn(random.PRNGKey(0), jax.ShapeDtypeStruct((8, 4), jnp.float32)) + init_fn(random.key(0), jax.ShapeDtypeStruct((8, 4), jnp.float32)) @temp_flip_flag('fix_rng_separator', True) def test_fold_in_static_seperator(self): - x = LazyRng(random.PRNGKey(0), ('ab', 'c')) - y = LazyRng(random.PRNGKey(0), ('a', 'bc')) + x = LazyRng(random.key(0), ('ab', 'c')) + y = LazyRng(random.key(0), ('a', 'bc')) self.assertFalse(np.all(x.as_jax_rng() == y.as_jax_rng())) diff --git a/tests/core/design/core_attention_test.py b/tests/core/design/core_attention_test.py index 2ddb468890..3ab4005509 100644 --- a/tests/core/design/core_attention_test.py +++ b/tests/core/design/core_attention_test.py @@ -156,7 +156,7 @@ def test_attention(self): attn_fn=with_dropout(softmax_attn, 0.1, deterministic=False), ) - rngs = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)} + rngs = {'params': random.key(0), 'dropout': random.key(1)} y, variables = jax.jit(init(model))(rngs, inputs, inputs) variable_shapes = jax.tree_util.tree_map(jnp.shape, variables['params']) self.assertEqual(y.shape, (2, 7, 16)) diff --git a/tests/core/design/core_auto_encoder_test.py b/tests/core/design/core_auto_encoder_test.py index 785caf626b..7eff74f4a5 100644 --- a/tests/core/design/core_auto_encoder_test.py +++ b/tests/core/design/core_auto_encoder_test.py @@ -104,7 +104,7 @@ class AutoEncoderTest(absltest.TestCase): def test_auto_encoder_hp_struct(self): ae = AutoEncoder(latents=2, features=4, hidden=3) x = jnp.ones((1, 4)) - x_r, variables = init(ae)(random.PRNGKey(0), x) + x_r, variables = init(ae)(random.key(0), x) self.assertEqual(x.shape, x_r.shape) variable_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) @@ -129,7 +129,7 @@ def test_auto_encoder_with_scope(self): ) x = jnp.ones((1, 4)) - x_r, variables = init(ae)(random.PRNGKey(0), x) + x_r, variables = init(ae)(random.key(0), x) self.assertEqual(x.shape, x_r.shape) variable_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, 
variables['params']) @@ -154,7 +154,7 @@ def test_auto_encoder_bind_method(self): )(x) x = jnp.ones((1, 4)) - x_r, variables = init(ae)(random.PRNGKey(0), x) + x_r, variables = init(ae)(random.key(0), x) self.assertEqual(x.shape, x_r.shape) variable_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) diff --git a/tests/core/design/core_big_resnets_test.py b/tests/core/design/core_big_resnets_test.py index d95cf16fc1..01d38362bf 100644 --- a/tests/core/design/core_big_resnets_test.py +++ b/tests/core/design/core_big_resnets_test.py @@ -68,8 +68,8 @@ def body_fn(scope, x): class BigResnetTest(absltest.TestCase): def test_big_resnet(self): - x = random.normal(random.PRNGKey(0), (1, 8, 8, 8)) - y, variables = init(big_resnet)(random.PRNGKey(1), x) + x = random.normal(random.key(0), (1, 8, 8, 8)) + y, variables = init(big_resnet)(random.key(1), x) self.assertEqual(y.shape, (1, 8, 8, 8)) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) diff --git a/tests/core/design/core_custom_vjp_test.py b/tests/core/design/core_custom_vjp_test.py index aa7516f456..c4354a3e40 100644 --- a/tests/core/design/core_custom_vjp_test.py +++ b/tests/core/design/core_custom_vjp_test.py @@ -60,8 +60,8 @@ def bwd(features, res, y_t): class CustomVJPTest(absltest.TestCase): def test_custom_vjp(self): - x = random.normal(random.PRNGKey(0), (1, 4)) - y, variables = init(mlp_custom_grad)(random.PRNGKey(1), x) + x = random.normal(random.key(0), (1, 4)) + y, variables = init(mlp_custom_grad)(random.key(1), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) diff --git a/tests/core/design/core_dense_test.py b/tests/core/design/core_dense_test.py index a90d93cf3b..ab5b15e723 100644 --- a/tests/core/design/core_dense_test.py +++ b/tests/core/design/core_dense_test.py @@ -115,7 +115,7 @@ class DenseTest(absltest.TestCase): def test_dense(self): model = Dense(features=4) x = jnp.ones((1, 3)) - y, variables = init(model)(random.PRNGKey(0), x) + y, variables = init(model)(random.key(0), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) @@ -130,7 +130,7 @@ def test_dense(self): def test_explicit_dense(self): x = jnp.ones((1, 3)) - y, variables = init(explicit_mlp)(random.PRNGKey(0), x) + y, variables = init(explicit_mlp)(random.key(0), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) @@ -145,7 +145,7 @@ def test_explicit_dense(self): def test_explicit_dense(self): x = jnp.ones((1, 4)) - y, variables = init(explicit_mlp)(random.PRNGKey(0), x) + y, variables = init(explicit_mlp)(random.key(0), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) @@ -160,7 +160,7 @@ def test_explicit_dense(self): def test_semi_explicit_dense(self): x = jnp.ones((1, 4)) - y, variables = init(semi_explicit_mlp)(random.PRNGKey(0), x) + y, variables = init(semi_explicit_mlp)(random.key(0), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) ) diff --git a/tests/core/design/core_flow_test.py b/tests/core/design/core_flow_test.py index 61feb66500..872674745b 100644 --- a/tests/core/design/core_flow_test.py +++ b/tests/core/design/core_flow_test.py @@ -67,7 +67,7 @@ class FlowTest(absltest.TestCase): def test_flow(self): x = jnp.ones((1, 3)) flow = StackFlow((DenseFlow(),) * 3) - y, variables = init(flow.forward)(random.PRNGKey(0), x) + y, variables = init(flow.forward)(random.key(0), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, 
variables['params']) ) diff --git a/tests/core/design/core_resnet_test.py b/tests/core/design/core_resnet_test.py index ac7628bd7e..85f30f0a1c 100644 --- a/tests/core/design/core_resnet_test.py +++ b/tests/core/design/core_resnet_test.py @@ -87,9 +87,9 @@ class ResNetTest(absltest.TestCase): def test_resnet(self): block_sizes = (2, 2) - x = random.normal(random.PRNGKey(0), (1, 64, 64, 3)) + x = random.normal(random.key(0), (1, 64, 64, 3)) y, variables = init(resnet)( - random.PRNGKey(1), x, block_sizes=block_sizes, features=16 + random.key(1), x, block_sizes=block_sizes, features=16 ) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) diff --git a/tests/core/design/core_scan_test.py b/tests/core/design/core_scan_test.py index 84d22d4344..fc8dc3c514 100644 --- a/tests/core/design/core_scan_test.py +++ b/tests/core/design/core_scan_test.py @@ -51,9 +51,9 @@ def body_fn(scope, c, x): class ScanTest(absltest.TestCase): def test_scan_unshared_params(self): - x = random.normal(random.PRNGKey(0), (1, 4)) + x = random.normal(random.key(0), (1, 4)) x = jnp.concatenate([x, x], 0) - y, variables = init(mlp_scan)(random.PRNGKey(1), x, share_params=False) + y, variables = init(mlp_scan)(random.key(1), x, share_params=False) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) @@ -71,9 +71,9 @@ def test_scan_unshared_params(self): self.assertFalse(jnp.allclose(k1, k2)) def test_scan_shared_params(self): - x = random.normal(random.PRNGKey(0), (1, 4)) + x = random.normal(random.key(0), (1, 4)) x = jnp.concatenate([x, x], 0) - y, variables = init(mlp_scan)(random.PRNGKey(1), x, share_params=True) + y, variables = init(mlp_scan)(random.key(1), x, share_params=True) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) diff --git a/tests/core/design/core_tied_autoencoder_test.py b/tests/core/design/core_tied_autoencoder_test.py index 2b20844dab..2cbae4e78d 100644 --- a/tests/core/design/core_tied_autoencoder_test.py +++ b/tests/core/design/core_tied_autoencoder_test.py @@ -53,7 +53,7 @@ class TiedAutoEncoderTest(absltest.TestCase): def test_tied_auto_encoder(self): ae = TiedAutoEncoder(latents=2, features=4) x = jnp.ones((1, ae.features)) - x_r, variables = init(ae)(random.PRNGKey(0), x) + x_r, variables = init(ae)(random.key(0), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) @@ -69,7 +69,7 @@ def test_tied_auto_encoder(self): def test_init_from_decoder(self): ae = TiedAutoEncoder(latents=2, features=4) z = jnp.ones((1, ae.latents)) - x_r, variables = init(ae.decode)(random.PRNGKey(0), z) + x_r, variables = init(ae.decode)(random.key(0), z) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) diff --git a/tests/core/design/core_vmap_test.py b/tests/core/design/core_vmap_test.py index 062fc2c1d6..4626f6db80 100644 --- a/tests/core/design/core_vmap_test.py +++ b/tests/core/design/core_vmap_test.py @@ -56,10 +56,10 @@ def mlp_vmap( class VMapTest(absltest.TestCase): def test_vmap_shared(self): - x = random.normal(random.PRNGKey(0), (1, 4)) + x = random.normal(random.key(0), (1, 4)) x = jnp.concatenate([x, x], 0) - y, variables = init(mlp_vmap)(random.PRNGKey(1), x, share_params=True) + y, variables = init(mlp_vmap)(random.key(1), x, share_params=True) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) @@ -75,10 +75,10 @@ def test_vmap_shared(self): self.assertTrue(jnp.allclose(y[0], y[1])) def test_vmap_unshared(self): - x = 
random.normal(random.PRNGKey(0), (1, 4)) + x = random.normal(random.key(0), (1, 4)) x = jnp.concatenate([x, x], 0) - y, variables = init(mlp_vmap)(random.PRNGKey(1), x, share_params=False) + y, variables = init(mlp_vmap)(random.key(1), x, share_params=False) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) diff --git a/tests/core/design/core_weight_std_test.py b/tests/core/design/core_weight_std_test.py index d4028c7b8d..64a2a8a431 100644 --- a/tests/core/design/core_weight_std_test.py +++ b/tests/core/design/core_weight_std_test.py @@ -55,13 +55,13 @@ class WeightStdTest(absltest.TestCase): def test_weight_std(self): x = random.normal( - random.PRNGKey(0), + random.key(0), ( 1, 4, ), ) - y, variables = init(mlp)(random.PRNGKey(1), x) + y, variables = init(mlp)(random.key(1), x) param_shapes = unfreeze( jax.tree_util.tree_map(jnp.shape, variables['params']) diff --git a/tests/cursor_test.py b/tests/cursor_test.py index 72c6672fae..268d267a66 100644 --- a/tests/cursor_test.py +++ b/tests/cursor_test.py @@ -170,7 +170,7 @@ def __call__(self, x): for freeze_wrap in (lambda x: x, freeze): params = freeze_wrap( - Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))['params'] + Model().init(jax.random.key(0), jnp.empty((1, 2)))['params'] ) c = cursor(params) @@ -193,7 +193,7 @@ def __call__(self, x): lambda x, y: (x == y).all(), params, freeze_wrap( - Model().init(jax.random.PRNGKey(0), jnp.empty((1, 2)))[ + Model().init(jax.random.key(0), jnp.empty((1, 2)))[ 'params' ] ), diff --git a/tests/linen/initializers_test.py b/tests/linen/initializers_test.py index 26274dd454..ec4ecb4872 100644 --- a/tests/linen/initializers_test.py +++ b/tests/linen/initializers_test.py @@ -45,7 +45,7 @@ class InitializersTest(parameterized.TestCase): }, ) def test_call_builder(self, builder_fn, params_shape, expected_params): - params = builder_fn()(random.PRNGKey(42), params_shape, jnp.float32) + params = builder_fn()(random.key(42), params_shape, jnp.float32) np.testing.assert_allclose(params, expected_params) @parameterized.parameters( @@ -60,7 +60,7 @@ def test_call_builder(self, builder_fn, params_shape, expected_params): ) def test_kernel_builder(self, builder_fn, expected_params): layer = nn.Dense(5, kernel_init=builder_fn()) - params = layer.init(random.PRNGKey(42), jnp.empty((3, 2)))['params'] + params = layer.init(random.key(42), jnp.empty((3, 2)))['params'] np.testing.assert_allclose(params['kernel'], expected_params) diff --git a/tests/linen/linen_activation_test.py b/tests/linen/linen_activation_test.py index af922361c3..b15a2547e2 100644 --- a/tests/linen/linen_activation_test.py +++ b/tests/linen/linen_activation_test.py @@ -31,7 +31,7 @@ class ActivationTest(parameterized.TestCase): def test_prelu(self): - rng = random.PRNGKey(0) + rng = random.key(0) x = jnp.ones((4, 6, 5)) act = nn.PReLU() y, _ = act.init_with_output(rng, x) diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py index 293db0e5a7..2556b292c2 100644 --- a/tests/linen/linen_attention_test.py +++ b/tests/linen/linen_attention_test.py @@ -37,7 +37,7 @@ class AttentionTest(parameterized.TestCase): def test_multihead_self_attention(self): - rng = random.PRNGKey(0) + rng = random.key(0) x = jnp.ones((4, 6, 5)) sa_module = nn.SelfAttention( num_heads=8, @@ -51,7 +51,7 @@ def test_multihead_self_attention(self): self.assertEqual(y.dtype, jnp.float32) def test_dtype_infer(self): - rng = random.PRNGKey(0) + rng = random.key(0) x = jnp.ones((4, 6, 5), jnp.complex64) sa_module = 
nn.SelfAttention( num_heads=8, @@ -65,7 +65,7 @@ def test_dtype_infer(self): self.assertEqual(y.dtype, jnp.complex64) def test_multihead_encoder_decoder_attention(self): - rng = random.PRNGKey(0) + rng = random.key(0) q = jnp.ones((4, 2, 3, 5)) kv = jnp.ones((4, 2, 3, 5)) sa_module = nn.MultiHeadDotProductAttention( @@ -79,7 +79,7 @@ def test_multihead_encoder_decoder_attention(self): self.assertEqual(y.shape, q.shape) def test_multihead_self_attention_w_dropout(self): - rng = random.PRNGKey(0) + rng = random.key(0) x = jnp.ones((4, 2, 3, 5)) sa_module = nn.MultiHeadDotProductAttention( num_heads=8, @@ -95,7 +95,7 @@ def test_multihead_self_attention_w_dropout(self): self.assertEqual(y.shape, x.shape) def test_multihead_self_attention_w_dropout_disabled(self): - rng = random.PRNGKey(0) + rng = random.key(0) x = jnp.ones((4, 2, 3, 5)) sa_module0 = nn.MultiHeadDotProductAttention( num_heads=8, @@ -152,7 +152,7 @@ def test_decoding(self, spatial_shape, attn_dims): bs = 2 num_heads = 3 num_features = 4 - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) inputs = random.normal( key1, (bs,) + spatial_shape + (num_heads * num_features,) @@ -189,7 +189,7 @@ def body_fn(state, x): def test_autoregresive_receptive_field_1d(self): """Tests the autoregresive self-attention receptive field.""" - rng = random.PRNGKey(0) + rng = random.key(0) rng1, rng2 = random.split(rng, num=2) length = 10 diff --git a/tests/linen/linen_combinators_test.py b/tests/linen/linen_combinators_test.py index 685c01b434..383f8248ba 100644 --- a/tests/linen/linen_combinators_test.py +++ b/tests/linen/linen_combinators_test.py @@ -78,7 +78,7 @@ class SequentialTest(absltest.TestCase): def test_construction(self): sequential = nn.Sequential([nn.Dense(4), nn.Dense(2)]) - key1, key2 = random.split(random.PRNGKey(0), 2) + key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (3, 1, 5)) params = sequential.init(key2, x) output = sequential.apply(params, x) @@ -87,7 +87,7 @@ def test_construction(self): def test_fails_if_layers_empty(self): sequential = nn.Sequential([]) with self.assertRaisesRegex(ValueError, 'Empty Sequential module'): - sequential.init(random.PRNGKey(42), jnp.ones((3, 5))) + sequential.init(random.key(42), jnp.ones((3, 5))) def test_same_output_as_mlp(self): sequential = nn.Sequential([ @@ -97,7 +97,7 @@ def test_same_output_as_mlp(self): ]) mlp = MLP(layer_sizes=[4, 8, 2]) - key1, key2 = random.split(random.PRNGKey(0), 2) + key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (3, 5)) params_1 = sequential.init(key2, x) params_2 = mlp.init(key2, x) @@ -122,7 +122,7 @@ def test_same_output_as_mlp_with_activation(self): activation_final=nn.log_softmax, ) - key1, key2 = random.split(random.PRNGKey(0), 2) + key1, key2 = random.split(random.key(0), 2) x = random.uniform(key1, (3, 5)) params_1 = sequential.init(key2, x) params_2 = mlp.init(key2, x) @@ -137,7 +137,7 @@ def test_tuple_output(self): AttentionTuple(), ]) - key1, key2 = random.split(random.PRNGKey(0), 2) + key1, key2 = random.split(random.key(0), 2) query = random.uniform(key1, (3, 5)) key_value = random.uniform(key1, (9, 5)) params_1 = sequential.init(key2, query, key_value) @@ -153,7 +153,7 @@ def test_dict_output(self): AttentionDict(), ]) - key1, key2 = random.split(random.PRNGKey(0), 2) + key1, key2 = random.split(random.key(0), 2) query = random.uniform(key1, (3, 5)) key_value = random.uniform(key1, (9, 5)) params_1 = sequential.init(key2, query, key_value) diff --git 
a/tests/linen/linen_linear_test.py b/tests/linen/linen_linear_test.py index b496823cd2..cd393c844b 100644 --- a/tests/linen/linen_linear_test.py +++ b/tests/linen/linen_linear_test.py @@ -36,7 +36,7 @@ class LinearTest(parameterized.TestCase): def test_dense(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((1, 3)) dense_module = nn.Dense( features=4, @@ -49,7 +49,7 @@ def test_dense(self): np.testing.assert_allclose(y, np.full((1, 4), 4.0)) def test_dense_extra_batch_dims(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((1, 2, 3)) dense_module = nn.Dense( features=4, @@ -60,7 +60,7 @@ def test_dense_extra_batch_dims(self): np.testing.assert_allclose(y, np.full((1, 2, 4), 4.0)) def test_dense_no_bias(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((1, 3)) dense_module = nn.Dense( features=4, @@ -71,24 +71,24 @@ def test_dense_no_bias(self): np.testing.assert_allclose(y, np.full((1, 4), 3.0)) def test_dense_is_dense_general(self): - x = jax.random.normal(random.PRNGKey(0), (5, 3)) + x = jax.random.normal(random.key(0), (5, 3)) dense_module = nn.Dense( features=4, use_bias=True, bias_init=initializers.normal(), ) - y1, _ = dense_module.init_with_output(dict(params=random.PRNGKey(1)), x) + y1, _ = dense_module.init_with_output(dict(params=random.key(1)), x) dg_module = nn.DenseGeneral( features=4, use_bias=True, bias_init=initializers.normal(), ) - y2, _ = dg_module.init_with_output(dict(params=random.PRNGKey(1)), x) + y2, _ = dg_module.init_with_output(dict(params=random.key(1)), x) np.testing.assert_allclose(y1, y2) def test_dense_general_batch_dim_raises(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((1, 3, 2, 5)) with self.assertRaises(ValueError): dg_module = nn.DenseGeneral( @@ -100,7 +100,7 @@ def test_dense_general_batch_dim_raises(self): dg_module.init_with_output(rng, x) def test_dense_general_two_out(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((1, 3)) dg_module = nn.DenseGeneral( features=(2, 2), @@ -111,7 +111,7 @@ def test_dense_general_two_out(self): np.testing.assert_allclose(y, np.full((1, 2, 2), 4.0)) def test_dense_general_two_in(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((1, 2, 2)) dg_module = nn.DenseGeneral( features=3, @@ -123,7 +123,7 @@ def test_dense_general_two_in(self): np.testing.assert_allclose(y, np.full((1, 3), 5.0)) def test_dense_general_batch_dim(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((2, 1, 3, 5)) state = {'counter': 0.0} @@ -152,7 +152,7 @@ def _counter_init(rng, shape, dtype, state): ((-2, 3), (0,), 'bijk,bjklm->bilm'), ]) def test_dense_general_vs_numpy(self, axis, batch_dims, einsum_expr): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((16, 8, 9, 10)) dg_module = nn.DenseGeneral( @@ -169,7 +169,7 @@ def test_dense_general_vs_numpy(self, axis, batch_dims, einsum_expr): def test_complex_params_dense(self): dense = nn.Dense(features=2, param_dtype=jnp.complex64) x = jnp.ones((1, 2), jnp.float32) - variables = dense.init(random.PRNGKey(0), x) + variables = dense.init(random.key(0), x) self.assertEqual(variables['params']['kernel'].dtype, jnp.complex64) self.assertEqual(variables['params']['bias'].dtype, jnp.complex64) y = dense.apply(variables, x) @@ -178,7 +178,7 @@ def 
test_complex_params_dense(self): def test_complex_input_dense(self): dense = nn.Dense(features=2) x = jnp.ones((1, 2), jnp.complex64) - variables = dense.init(random.PRNGKey(0), x) + variables = dense.init(random.key(0), x) self.assertEqual(variables['params']['kernel'].dtype, jnp.float32) self.assertEqual(variables['params']['bias'].dtype, jnp.float32) y = dense.apply(variables, x) @@ -186,7 +186,7 @@ def test_complex_input_dense(self): @parameterized.product(use_bias=(True, False)) def test_conv(self, use_bias): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((1, 8, 3)) conv_module = nn.Conv( features=4, @@ -203,7 +203,7 @@ def test_conv(self, use_bias): @parameterized.product(use_bias=(True, False)) def test_multibatch_input_conv(self, use_bias): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((2, 5, 8, 3)) conv_module = nn.Conv( features=4, @@ -219,7 +219,7 @@ def test_multibatch_input_conv(self, use_bias): np.testing.assert_allclose(y, np.full((2, 5, 6, 4), expected)) def test_conv_local(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((1, 8, 2)) conv_module = nn.ConvLocal( features=4, @@ -233,7 +233,7 @@ def test_conv_local(self): np.testing.assert_allclose(y, np.full((1, 6, 4), 7.0)) def test_single_input_conv(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((8, 3)) conv_module = nn.Conv( features=4, @@ -247,7 +247,7 @@ def test_single_input_conv(self): np.testing.assert_allclose(y, np.full((6, 4), 10.0)) def test_single_input_masked_conv(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((8, 3)) m = jnp.tril(jnp.ones((3, 3, 4))) conv_module = nn.Conv( @@ -271,7 +271,7 @@ def test_single_input_masked_conv(self): np.testing.assert_allclose(y, expected) def test_single_input_conv_local(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((8, 2)) conv_module = nn.ConvLocal( features=4, @@ -285,7 +285,7 @@ def test_single_input_conv_local(self): np.testing.assert_allclose(y, np.full((6, 4), 7.0)) def test_group_conv(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((1, 8, 4)) conv_module = nn.Conv( features=4, @@ -323,7 +323,7 @@ def test_circular_conv_1d_constant( dimension) and have all elements equal to `n_input_features * kernel_lin_size`. """ - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((n_batch, input_size, n_input_features)) conv_module = module( features=n_features, @@ -385,7 +385,7 @@ def test_circular_conv_2d_constant( dimension) and have all elements equal to `n_input_features * kernel_lin_size^2`. 
""" - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((n_batch, input_x_size, input_y_size, n_input_features)) kernel_size = (kernel_lin_size, kernel_lin_size) conv_module = module( @@ -413,7 +413,7 @@ def test_circular_conv_2d_constant( def test_circular_conv_1d_custom(self): """Test 1d convolution with circular padding and a stride.""" - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = np.arange(1, 6) x = np.expand_dims(x, (0, 2)) kernel = np.array((1, 2, 1)) @@ -439,7 +439,7 @@ def test_circular_conv_local_1d_custom(self): """ Test 1d local convolution with circular padding and a stride """ - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = np.arange(1, 6) x = np.expand_dims(x, (0, 2)) kernel = np.array(((-1, 2, 3), (4, 5, 6))) @@ -462,7 +462,7 @@ def test_circular_conv_local_1d_custom(self): def test_circular_conv_1d_dilation(self): """Test 1d convolution with circular padding and kernel dilation.""" - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = np.arange(1, 6) x = np.expand_dims(x, (0, 2)) kernel = np.array((1, 2, 1)) @@ -494,7 +494,7 @@ def test_circular_conv_local_1d_dilation(self): """ Test 1d local convolution with circular padding and kernel dilation """ - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = np.arange(1, 6) x = np.expand_dims(x, (0, 2)) kernel = np.array( @@ -526,7 +526,7 @@ def test_circular_conv_local_1d_dilation(self): def test_circular_conv_2d_custom(self): """Test 2d convolution with circular padding on a 3x3 example.""" - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = np.array(((1, 2, 3), (4, 5, 6), (7, 8, 9))) x = np.expand_dims(x, (0, 3)) kernel = np.array(((0, 1, 0), (1, 2, 1), (0, 1, 0))) @@ -555,7 +555,7 @@ def test_circular_conv_local_2d_custom(self): """ Test 2d local convolution with circular padding on a 3x3 example """ - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = np.array(((1, 2, 3), (4, 5, 6), (7, 8, 9))) x = np.expand_dims(x, (0, 3)) kernel = np.array(( @@ -598,7 +598,7 @@ def test_circular_conv_local_2d_custom(self): np.testing.assert_allclose(y, correct_ans) def test_causal_conv1d(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((1, 8, 4)) conv_module = nn.Conv( features=4, @@ -625,7 +625,7 @@ def test_causal_conv1d(self): use_bias=(True, False), ) def test_conv_transpose(self, use_bias): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((1, 8, 3)) conv_transpose_module = nn.ConvTranspose( features=4, @@ -657,7 +657,7 @@ def test_conv_transpose(self, use_bias): use_bias=(True, False), ) def test_multibatch_input_conv_transpose(self, use_bias): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((2, 5, 8, 3)) conv_transpose_module = nn.ConvTranspose( features=4, @@ -688,7 +688,7 @@ def test_multibatch_input_conv_transpose(self, use_bias): np.testing.assert_allclose(y, correct_ans) def test_single_input_conv_transpose(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((8, 3)) conv_transpose_module = nn.ConvTranspose( features=4, @@ -714,7 +714,7 @@ def test_single_input_conv_transpose(self): np.testing.assert_allclose(y, correct_ans) def test_single_input_masked_conv_transpose(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((8, 3)) m = 
jnp.tril(jnp.ones((3, 3, 4))) conv_transpose_module = nn.ConvTranspose( @@ -758,7 +758,7 @@ def test_circular_conv_transpose_1d_constant( dimension) and have all elements equal to `n_input_features * kernel_lin_size`. """ - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((n_batch, input_size, n_input_features)) conv_module = nn.ConvTranspose( features=n_features, @@ -802,7 +802,7 @@ def test_circular_conv_transpose_2d_constant( dimension) and have all elements equal to `n_input_features * kernel_lin_size^2`. """ - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((n_batch, input_x_size, input_y_size, n_input_features)) conv_module = nn.ConvTranspose( features=n_features, @@ -828,7 +828,7 @@ def test_circular_conv_transpose_2d_with_vmap(self): # this is ok sample_input = jnp.ones((1, 32, 2)) - out, vars = layer.init_with_output(jax.random.PRNGKey(0), sample_input) + out, vars = layer.init_with_output(jax.random.key(0), sample_input) self.assertEqual(out.shape, (1, 32, 5)) batch_input = jnp.ones((8, 32, 2)) @@ -840,7 +840,7 @@ def test_circular_conv_transpose_2d_with_vmap(self): def test_circular_conv_transpose_1d_custom(self): """Test 1d transposed convolution with circular padding and a stride.""" - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = np.arange(1, 6) x = np.expand_dims(x, (0, 2)) kernel = np.array((1, 2, 1)) @@ -880,7 +880,7 @@ def test_circular_conv_transpose_1d_custom(self): def test_circular_conv_transpose_2d_custom(self): """Test 2d transposed convolution with circular padding on a 3x3 example.""" - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = np.array(( (1, 2, 3), (4, 5, 6), @@ -911,7 +911,7 @@ def test_circular_conv_transpose_2d_custom(self): def test_circular_conv_transpose_2d_custom_bias(self): """Test 2d transposed convolution with circular padding on a 2x2 example with bias.""" - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = np.array(((1, 2), (3, 4))) x = np.expand_dims(x, (0, 3)) kernel = np.array(( @@ -940,7 +940,7 @@ def test_circular_conv_transpose_2d_custom_bias(self): @parameterized.product(use_bias=(True, False)) def test_transpose_kernel_conv_transpose(self, use_bias): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.ones((1, 15, 15, 3)) conv_module = nn.ConvTranspose( features=4, @@ -959,10 +959,10 @@ def test_int_kernel_size(self, module): conv = module(features=4, kernel_size=3) x = jnp.ones((8, 3)) with self.assertRaises(TypeError): - conv.init(random.PRNGKey(0), x) + conv.init(random.key(0), x) def test_embed(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.arange(4)[None] dummy_embedding = jnp.broadcast_to(jnp.arange(4)[..., None], (4, 3)).astype( jnp.float32 @@ -980,7 +980,7 @@ def test_embed(self): np.testing.assert_allclose(z, 3.0 * jnp.arange(4)) def test_embed_numpy(self): - rng = dict(params=random.PRNGKey(0)) + rng = dict(params=random.key(0)) x = jnp.arange(4)[None] dummy_embedding = np.broadcast_to(np.arange(4)[..., None], (4, 3)).astype( np.float32 @@ -1009,7 +1009,7 @@ def __call__(self, x): return nn.DenseGeneral(features=6, axis=1, name='dense')(x) x = jnp.ones((2, 4, 8)) - y, variables = Foo().init_with_output(random.PRNGKey(0), x) + y, variables = Foo().init_with_output(random.key(0), x) self.assertEqual( jax.tree_util.tree_map(jnp.shape, variables['params']), {'dense': {'kernel': (4, 6), 'bias': (6,)}}, @@ 
-1024,7 +1024,7 @@ def __call__(self, x): return nn.DenseGeneral(features=6, axis=(0, 1), name='dense')(x) x = jnp.ones((2, 4, 8)) - y, variables = Foo().init_with_output(random.PRNGKey(0), x) + y, variables = Foo().init_with_output(random.key(0), x) self.assertEqual( jax.tree_util.tree_map(jnp.shape, variables['params']), {'dense': {'kernel': (2, 4, 6), 'bias': (6,)}}, diff --git a/tests/linen/linen_meta_test.py b/tests/linen/linen_meta_test.py index b6c1a57f84..eaebda2b64 100644 --- a/tests/linen/linen_meta_test.py +++ b/tests/linen/linen_meta_test.py @@ -54,7 +54,7 @@ def __call__(self, xs): )(name='bar')(xs) m = Foo() - variables = m.init(random.PRNGKey(0), jnp.zeros((8, 3))) + variables = m.init(random.key(0), jnp.zeros((8, 3))) self.assertEqual( variables['params']['bar']['kernel'].names, ('batch', 'in', 'out') ) @@ -94,7 +94,7 @@ def __call__(self, xs): )(name='bar')(xs) m = Foo() - variables = m.init(random.PRNGKey(0), jnp.zeros((8, 3))) + variables = m.init(random.key(0), jnp.zeros((8, 3))) self.assertEqual( variables['params']['bar']['kernel'].names, ('batch', 'in', 'out') ) @@ -118,7 +118,7 @@ def __call__(self, xs): # variable_axes={'params': 0}, split_rngs={'params': True}, # metadata_params={nn.PARTITION_NAME: 'batch'})(scope, xs) - # _, variables = init(f)(random.PRNGKey(0), jnp.zeros((8, 3))) + # _, variables = init(f)(random.key(0), jnp.zeros((8, 3))) # self.assertEqual(variables['params']['kernel'].names, # ('batch', 'in', 'out')) @@ -159,9 +159,7 @@ def body(_, c): mesh = Mesh(devs, ['data', 'model']) model = Model() x = jnp.ones((8, 128)) - spec = nn.get_partition_spec( - jax.eval_shape(model.init, random.PRNGKey(0), x) - ) + spec = nn.get_partition_spec(jax.eval_shape(model.init, random.key(0), x)) self.assertEqual( spec, { @@ -181,16 +179,13 @@ def body(_, c): ) x_spec = PartitionSpec('data', 'model') f = lambda x: jax.sharding.NamedSharding(mesh, x) - if jax.config.jax_enable_custom_prng: - key_spec = PartitionSpec() - else: - key_spec = PartitionSpec(None) + key_spec = PartitionSpec() init_fn = jax.jit( model.init, in_shardings=jax.tree_map(f, (key_spec, x_spec)), out_shardings=jax.tree_map(f, spec), ) - variables = init_fn(random.PRNGKey(0), x) + variables = init_fn(random.key(0), x) apply_fn = jax.jit( model.apply, in_shardings=jax.tree_map(f, (spec, x_spec)), diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index f2b5da866a..a67d8c556a 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -104,7 +104,7 @@ def __call__(self): class ModuleTest(absltest.TestCase): def test_init_module(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) x = jnp.array([1.0]) scope = Scope({}, {'params': rngkey}, mutable=['params']) y = DummyModule(parent=scope)(x) @@ -126,7 +126,7 @@ def __call__(self, x): # provide a massive input message which would OOM if any compute ops were actually executed variables = Foo().lazy_init( - random.PRNGKey(0), + random.key(0), jax.ShapeDtypeStruct((1024 * 1024 * 1024, 128), jnp.float32), ) self.assertEqual(variables['params']['kernel'].shape, (128, 128)) @@ -140,12 +140,10 @@ def __call__(self, x): return x * k with self.assertRaises(errors.LazyInitError): - Foo().lazy_init( - random.PRNGKey(0), jax.ShapeDtypeStruct((8, 4), jnp.float32) - ) + Foo().lazy_init(random.key(0), jax.ShapeDtypeStruct((8, 4), jnp.float32)) def test_arg_module(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) x = jnp.ones((10,)) scope = Scope({}, {'params': rngkey}, 
mutable=['params']) y = Dense(3, parent=scope)(x) @@ -155,7 +153,7 @@ def test_arg_module(self): self.assertEqual(params['kernel'].shape, (10, 3)) def test_util_fun(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class MLP(nn.Module): @@ -181,7 +179,7 @@ def _mydense(self, x): ) def test_nested_module_reuse(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class MLP(nn.Module): @@ -221,7 +219,7 @@ def __call__(self, x): ) def test_setup_dict_assignment(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class MLP(nn.Module): @@ -261,7 +259,7 @@ def __call__(self, x): foo = Foo() x = jnp.ones(shape=(1, 3)) - params = foo.init(random.PRNGKey(0), x)['params'] + params = foo.init(random.key(0), x)['params'] param_shape = jax.tree_util.tree_map(jnp.shape, params) self.assertEqual( param_shape, {'a_(1, 2)': {'kernel': (3, 2), 'bias': (2,)}} @@ -277,7 +275,7 @@ def setup(self): unused_clone = MLP(parent=scope).clone() def test_submodule_attr(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class Inner(nn.Module): @@ -310,7 +308,7 @@ def __call__(self): self.assertEqual(40, scope.variables()['params']['inner']['x']) def test_param_in_setup(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class DummyModuleWithoutCompact(nn.Module): xshape: Tuple[int, ...] @@ -331,7 +329,7 @@ def __call__(self, x): self.assertEqual(params, {'bias': jnp.array([1.0])}) def test_init_outside_setup_without_compact(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class DummyModuleWithoutCompact(nn.Module): @@ -345,7 +343,7 @@ def __call__(self, x): unused_y = DummyModuleWithoutCompact(parent=scope)(x) def test_init_outside_call(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class Dummy(nn.Module): @@ -364,7 +362,7 @@ def foo(self, x): unused_y = Dummy(parent=scope).foo(x) def test_setup_call_var_collision(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: Tuple[int, ...] @@ -384,7 +382,7 @@ def __call__(self, x): unused_y = Dummy(x.shape, parent=scope)(x) def test_call_var_collision(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: Tuple[int, ...] @@ -402,7 +400,7 @@ def __call__(self, x): unused_y = Dummy(x.shape, parent=scope)(x) def test_setup_var_collision(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: Tuple[int, ...] @@ -421,7 +419,7 @@ def __call__(self, x): unused_y = Dummy(x.shape, parent=scope)(x) def test_setattr_name_var_disagreement_allowed_in_lists(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: Tuple[int, ...] @@ -441,7 +439,7 @@ def __call__(self, x): self.assertEqual(y, jnp.array([2.0])) def test_setattr_name_var_disagreement_allowed_in_dicts(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: Tuple[int, ...] @@ -466,7 +464,7 @@ def __call__(self, x): self.assertEqual(y, jnp.array([2.0])) def test_submodule_var_collision_with_scope(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: Tuple[int, ...] @@ -485,7 +483,7 @@ def __call__(self, x): unused_y = Dummy(x.shape, parent=scope)(x) def test_submodule_var_collision_with_submodule(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: Tuple[int, ...] 
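# --- Illustrative aside, not part of the diff hunks above or below. A minimal
# sketch of the old-style vs. new-style PRNG key behavior that this migration
# relies on, assuming jax >= 0.4.16 as noted in the CHANGELOG entry; the names
# below (old_key, new_key, raw, roundtrip) are made up for illustration.
import jax
import jax.numpy as jnp

old_key = jax.random.PRNGKey(0)   # old style: uint32 array of shape (2,)
new_key = jax.random.key(0)       # new style: scalar array with an opaque key dtype
assert old_key.shape == (2,) and new_key.shape == ()
assert jnp.issubdtype(new_key.dtype, jax.dtypes.prng_key)

# With the default PRNG implementation both styles should drive jax.random
# identically, which is why these tests can swap PRNGKey -> key mechanically
# without changing any expected outputs.
x = jax.random.normal(new_key, (2, 3))

# Converting between the two representations (e.g. for checkpoints that stored
# raw uint32 key data) is expected to work via:
raw = jax.random.key_data(new_key)          # -> uint32 array of shape (2,)
roundtrip = jax.random.wrap_key_data(raw)   # -> typed key again

# Because a typed key is a scalar (its key data lives inside the dtype), its
# sharding spec is empty, which appears to be why the linen_meta_test hunk
# above can drop the jax_enable_custom_prng branch and always use
# key_spec = PartitionSpec().
# --- End of aside.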
@@ -506,7 +504,7 @@ def __call__(self, x): unused_y = Dummy(x.shape, parent=scope)(x) def test_submodule_var_collision_with_params(self): - rngkey = jax.random.PRNGKey(0) + rngkey = jax.random.key(0) class Dummy(nn.Module): xshape: Tuple[int, ...] @@ -591,7 +589,7 @@ def __call__(self, x): 'wrapped in `@compact`' ) with self.assertRaisesRegex(errors.AssignSubModuleError, msg): - Foo().init(random.PRNGKey(0), jnp.ones((1, 3))) + Foo().init(random.key(0), jnp.ones((1, 3))) def test_forgotten_compact_annotation_with_explicit_parent(self): class Bar(nn.Module): @@ -613,7 +611,7 @@ def __call__(self, x): 'wrapped in `@compact`' ) with self.assertRaisesRegex(errors.AssignSubModuleError, msg): - Foo().init(random.PRNGKey(0), jnp.ones((1, 3))) + Foo().init(random.key(0), jnp.ones((1, 3))) def test_numpy_array_shape_class_args(self): class MLP(nn.Module): @@ -626,7 +624,7 @@ def __call__(self, x): return nn.Dense(self.widths[-1])(x) test = MLP(np.array([3, 3], np.int32)) - params = test.init({'params': random.PRNGKey(42)}, jnp.ones((3, 3))) + params = test.init({'params': random.key(42)}, jnp.ones((3, 3))) _ = test.apply(params, jnp.ones((3, 3))) def test_get_local_methods(self): @@ -690,7 +688,7 @@ class Test4(Test2): def __call__(self, x): return x - key = random.PRNGKey(0) + key = random.key(0) x = jnp.ones((5,)) test1 = Test(bar=4) test2 = Test2(bar=4, baz=2) @@ -762,8 +760,8 @@ def __call__(self, x): x = lyr(x) return x - x = random.uniform(random.PRNGKey(0), (5, 5)) - variables = Test().init(random.PRNGKey(0), jnp.ones((5, 5))) + x = random.uniform(random.key(0), (5, 5)) + variables = Test().init(random.key(0), jnp.ones((5, 5))) y = Test().apply(variables, x) m0 = variables['params']['layers_0']['kernel'] m1 = variables['params']['layers_2']['kernel'] @@ -840,7 +838,7 @@ def __call__(self, x): ) )""" x = jnp.ones((1, 2)) - trace, variables = mlp.init_with_output(random.PRNGKey(0), x) + trace, variables = mlp.init_with_output(random.key(0), x) self.assertEqual(trace, expected_trace) trace = mlp.apply(variables, x) self.assertEqual(trace, expected_trace) @@ -977,7 +975,7 @@ def __call__(self, x): foo = Foo() x = jnp.ones((2,)) - variables = foo.init(random.PRNGKey(0), x) + variables = foo.init(random.key(0), x) self.assertEqual(variables['params']['bar']['kernel'].shape, (2, 3)) def test_noncompact_module_frozen(self): @@ -994,7 +992,7 @@ def __call__(self): 'outside of setup method.' ) with self.assertRaisesRegex(errors.SetAttributeFrozenModuleError, msg): - Foo().init(random.PRNGKey(0)) + Foo().init(random.key(0)) def test_compact_module_frozen(self): class Foo(nn.Module): @@ -1008,7 +1006,7 @@ def __call__(self): 'outside of setup method.' ) with self.assertRaisesRegex(errors.SetAttributeFrozenModuleError, msg): - Foo().init(random.PRNGKey(0)) + Foo().init(random.key(0)) def test_submodule_frozen(self): class Foo(nn.Module): @@ -1023,7 +1021,7 @@ def __call__(self): 'is frozen outside of setup method.' 
) with self.assertRaisesRegex(errors.SetAttributeFrozenModuleError, msg): - Foo().init(random.PRNGKey(0)) + Foo().init(random.key(0)) def test_module_call_not_implemented(self): class Foo(nn.Module): @@ -1031,7 +1029,7 @@ class Foo(nn.Module): msg = '"Foo" object has no attribute "__call__"' with self.assertRaisesRegex(AttributeError, msg): - Foo().init(random.PRNGKey(0)) + Foo().init(random.key(0)) def test_is_mutable_collection(self): class EmptyModule(nn.Module): @@ -1062,7 +1060,7 @@ def __call__(self, x): y2 = self.a(x) return y1, y2 - key = random.PRNGKey(0) + key = random.key(0) x = jnp.ones((2,)) (y1, y2), unused_vars = B().init_with_output(key, x) @@ -1088,7 +1086,7 @@ def __call__(self, x): y2 = self.a(x) return y1, y2 - key = random.PRNGKey(0) + key = random.key(0) x = jnp.ones((2,)) _ = B().init_with_output(key, x) @@ -1110,7 +1108,7 @@ def setup(self): msg = '"B" object has no attribute "c"' with self.assertRaisesRegex(AttributeError, msg): - A().init(random.PRNGKey(0)) + A().init(random.key(0)) def test_unbound_setup_call(self): setup_called = False @@ -1142,7 +1140,7 @@ class B(nn.Module): def __call__(self, x): return self.foo(x) - variables = A().init(random.PRNGKey(0), jnp.ones((1,))) + variables = A().init(random.key(0), jnp.ones((1,))) var_shapes = jax.tree_util.tree_map(jnp.shape, variables) ref_var_shapes = { 'params': { @@ -1167,7 +1165,7 @@ def setup(self): def __call__(self, x): return self.foo(x) - variables = B().init(random.PRNGKey(0), jnp.ones((1,))) + variables = B().init(random.key(0), jnp.ones((1,))) var_shapes = jax.tree_util.tree_map(jnp.shape, variables) ref_var_shapes = { 'params': { @@ -1210,7 +1208,7 @@ def __call__(self, x): model = Model(encoder=encoder, n_out=5) # Initialize. - key = jax.random.PRNGKey(0) + key = jax.random.key(0) x = random.uniform(key, (4, 4)) variables = model.init(key, x) @@ -1255,7 +1253,7 @@ def __call__(self, c, x): a_pytree = {'foo': A(), 'bar': A()} b = B(a_pytree) - key = random.PRNGKey(0) + key = random.key(0) x = jnp.ones((2, 2)) params = B(a_pytree).init(key, x, x) @@ -1304,7 +1302,7 @@ class C(nn.Module): def __call__(self, x): return dense(2)(x) + self.b(x) + self.a(x) - key = random.PRNGKey(0) + key = random.key(0) x = jnp.ones((2, 2)) a = A() b = B(a) @@ -1352,7 +1350,7 @@ def __call__(self, x): a = A(name='foo') b = B(a=a) - k = jax.random.PRNGKey(0) + k = jax.random.key(0) x = jnp.zeros((5, 5)) init_vars = b.init(k, x) var_shapes = jax.tree_util.tree_map(jnp.shape, init_vars) @@ -1401,7 +1399,7 @@ class B(nn.Module): def __call__(self, x): return self.A['foo'](x) + self.A['bar'](x) + self.A['baz'](x) - key = random.PRNGKey(0) + key = random.key(0) x = jnp.ones((2, 2)) a = A() @@ -1442,11 +1440,11 @@ def __call__(self, x, **sow_args): self.sow('intermediates', 'h', 2 * x, **sow_args) return 3 * x - variables = Foo().init(random.PRNGKey(0), 1) + variables = Foo().init(random.key(0), 1) # During init we should not collect intermediates by default... self.assertNotIn('intermediates', variables) # ...unless we override mutable. 
- variables = Foo().init(random.PRNGKey(0), 1, mutable=True) + variables = Foo().init(random.key(0), 1, mutable=True) self.assertEqual(variables, {'intermediates': {'h': (1, 2)}}) _, state = Foo().apply({}, 1, mutable=['intermediates']) @@ -1495,9 +1493,9 @@ def loss(params, perturbations, inputs, targets): preds = Foo().apply(variables, inputs) return jnp.square(preds - targets).mean() - x = jax.random.uniform(jax.random.PRNGKey(1), shape=(10,)) - y = jax.random.uniform(jax.random.PRNGKey(2), shape=(10,)) - variables = Foo().init(jax.random.PRNGKey(0), x) + x = jax.random.uniform(jax.random.key(1), shape=(10,)) + y = jax.random.uniform(jax.random.key(2), shape=(10,)) + variables = Foo().init(jax.random.key(0), x) intm_grads = jax.grad(loss, argnums=1)( variables['params'], variables['perturbations'], x, y ) @@ -1517,9 +1515,9 @@ def __call__(self, x): x = self.perturb('after_multiply', x) return x - x = jax.random.uniform(jax.random.PRNGKey(1), shape=(10,)) + x = jax.random.uniform(jax.random.key(1), shape=(10,)) module = Foo() - variables = module.init(jax.random.PRNGKey(0), x) + variables = module.init(jax.random.key(0), x) params = variables['params'] perturbations = variables['perturbations'] @@ -1550,7 +1548,7 @@ def f(foo, x): x = jnp.ones((4,)) f_init = nn.init_with_output(f, foo) f_apply = nn.apply(f, foo) - y1, variables = f_init(random.PRNGKey(0), x) + y1, variables = f_init(random.key(0), x) y2 = f_apply(variables, x) self.assertEqual(y1, y2) @@ -1568,7 +1566,7 @@ def f(foo, x): foo = Foo() x = jnp.ones((4,)) f_init = nn.init_with_output(f, foo) - y1, variables = f_init(random.PRNGKey(0), x) + y1, variables = f_init(random.key(0), x) y2 = f(foo.bind(variables), x) self.assertEqual(y1, y2) @@ -1588,7 +1586,7 @@ def f(foo, x): foo = Foo() x = jnp.ones((4,)) f_init = nn.init_with_output(f, foo) - y1, variables = f_init(random.PRNGKey(0), x) + y1, variables = f_init(random.key(0), x) foo_b = foo.bind(variables, mutable='batch_stats') y2 = f(foo_b, x) y3, new_state = nn.apply(f, foo, mutable='batch_stats')(variables, x) @@ -1615,7 +1613,7 @@ def __call__(self, x): foo = Foo() x = jnp.ones((2,)) - variables = foo.init(random.PRNGKey(0), x) + variables = foo.init(random.key(0), x) encoder, encoder_vars = foo.bind(variables).encoder.unbind() decoder, decoder_vars = foo.bind(variables).decoder.unbind() @@ -1639,7 +1637,7 @@ def __call__(self, x): return nn.Dense(2)(x) x = jnp.ones((3,)) - variables = Foo().init(random.PRNGKey(0), x) + variables = Foo().init(random.key(0), x) y = Foo().apply(variables, x) self.assertEqual(y.shape, (2,)) @@ -1657,7 +1655,7 @@ def __call__(self, x): y = super().__call__(x) return nn.Dense(3)(y) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((4, 7)) variables = Bar().init(k, x) @@ -1688,7 +1686,7 @@ def __call__(self, x): y = self.a(x) return self.b(y) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((4, 7)) variables = Bar().init(k, x) @@ -1725,10 +1723,8 @@ class MyModule2(nn.Module): def test_jit_rng_equivalance(self): model = nn.Dense(1, use_bias=False) jit_model = nn.jit(nn.Dense)(1, use_bias=False) - param = model.init(random.PRNGKey(0), np.ones((1, 1)))['params']['kernel'] - param_2 = jit_model.init(random.PRNGKey(0), np.ones((1, 1)))['params'][ - 'kernel' - ] + param = model.init(random.key(0), np.ones((1, 1)))['params']['kernel'] + param_2 = jit_model.init(random.key(0), np.ones((1, 1)))['params']['kernel'] self.assertEqual(param, param_2) def test_rng_reuse_after_rewind(self): @@ -1757,7 +1753,7 @@ def __call__(self): x1 = a() 
return jnp.all(x0 == x1) - k = random.PRNGKey(0) + k = random.key(0) rng_equals = B().apply({}, rngs={'dropout': k}) self.assertFalse(rng_equals) @@ -1857,7 +1853,7 @@ def __call__(self): foo = Foo() with self.assertRaisesRegex(ValueError, 'RNGs.*unbound module'): foo() - k = random.PRNGKey(0) + k = random.key(0) self.assertTrue(foo.apply({}, rngs={'bar': k})) self.assertFalse(foo.apply({}, rngs={'baz': k})) @@ -1868,7 +1864,7 @@ def __call__(self): return self.is_initializing() foo = Foo() - k = random.PRNGKey(0) + k = random.key(0) self.assertTrue(foo.init_with_output(k)[0]) self.assertFalse(foo.apply({})) @@ -1879,8 +1875,8 @@ class B(nn.Module): def __call__(self, x): return x - k = random.PRNGKey(0) - x = random.uniform(random.PRNGKey(1), (2,)) + k = random.key(0) + x = random.uniform(random.key(1), (2,)) with self.assertRaises(errors.InvalidInstanceModuleError): B.init(k, x) # B is module class, not B() a module instance @@ -1909,7 +1905,7 @@ def __call__(self, input): r = A(x=3) with self.assertRaises(errors.IncorrectPostInitOverrideError): - r.init(jax.random.PRNGKey(2), jnp.ones(3)) + r.init(jax.random.key(2), jnp.ones(3)) def test_deepcopy_unspecified_parent(self): parent_parameter = inspect.signature(DummyModule).parameters['parent'] @@ -2001,7 +1997,7 @@ def __call__(self, x): return self.bar.baz(x) module = Foo() - y, variables = module.init_with_output(jax.random.PRNGKey(0), 1) + y, variables = module.init_with_output(jax.random.key(0), 1) self.assertEqual(y, 3) def test_getattribute_triggers_setup(self): @@ -2020,7 +2016,7 @@ def __call__(self, x): return self.b.fn1(x) a = A(b=B()) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.zeros((2,)) vs = nn.init(lambda a, x: a(x), a)(k, x) y = nn.apply(lambda a, x: a.b.fn1(x), a)(vs, x) @@ -2037,7 +2033,7 @@ def __call__(self, x): return self.seq.layers[0](x) module = Foo() - variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 10))) + variables = module.init(jax.random.key(0), jnp.ones((1, 10))) def test_setup_called_bounded_submodules(self): module = nn.Sequential([ @@ -2050,7 +2046,7 @@ def test_setup_called_bounded_submodules(self): nn.Dense(2), ]) x = jnp.ones((1, 3)) - variables = module.init(jax.random.PRNGKey(0), x) + variables = module.init(jax.random.key(0), x) bound_module = module.bind(variables) self.assertIsNotNone(bound_module.layers[0].layers[0].scope) @@ -2078,7 +2074,7 @@ def __call__(self, x): module = Foo(bars=[]) module.bars = [Bar(a=1)] - variables = module.init(jax.random.PRNGKey(0), jnp.ones(())) + variables = module.init(jax.random.key(0), jnp.ones(())) bound_module = module.bind(variables) bar1 = bound_module.bars[0] @@ -2110,7 +2106,7 @@ def setup(self): def __call__(self, x): # y = self.bar(x) - y, bar_vars = self.bar.init_with_output(jax.random.PRNGKey(0), x) + y, bar_vars = self.bar.init_with_output(jax.random.key(0), x) return y, bar_vars # create foo @@ -2118,7 +2114,7 @@ def __call__(self, x): # run foo (y, bar_vars), variables = module.init_with_output( - jax.random.PRNGKey(0), jnp.ones(()) + jax.random.key(0), jnp.ones(()) ) self.assertIn('params', bar_vars) @@ -2154,7 +2150,7 @@ def __call__(self, x): b = Unshared(shared=sh) module = Super(a=a, b=b) - rng = jax.random.PRNGKey(0) + rng = jax.random.key(0) params = module.init(rng, jnp.ones(1))['params'] module.apply({'params': params}, jnp.ones(1)) # works as expected @@ -2242,7 +2238,7 @@ class Foo(nn.Module): Foo(a=1, parent=None) # type: ignore[call-arg] def test_module_path_empty(self): - rngkey = jax.random.PRNGKey(0) + rngkey = 
jax.random.key(0) scope = Scope({}, {'params': rngkey}, mutable=['params']) m1 = DummyModule(parent=scope) @@ -2304,8 +2300,8 @@ def __call__(self, x): return x a = A() - k = random.PRNGKey(0) - x = random.uniform(random.PRNGKey(42), (2,)) + k = random.key(0) + x = random.uniform(random.key(42), (2,)) _ = a.init(k, x) expected_module_paths = [ (), @@ -2366,7 +2362,7 @@ def __call__(self, x): mod = CompactModule() x = jnp.ones(shape=(1, 3)) - variables = mod.init(jax.random.PRNGKey(0), x) + variables = mod.init(jax.random.key(0), x) call_modules = [] def log_interceptor(f, args, kwargs, context): @@ -2395,7 +2391,7 @@ def __call__(self, x): mod = SetupModule() x = jnp.ones(shape=(1, 3)) - variables = mod.init(jax.random.PRNGKey(0), x) + variables = mod.init(jax.random.key(0), x) call_modules = [] log = [] @@ -2565,7 +2561,7 @@ def sample_from_prior(rng, inp): with patch.object(gc, 'collect', return_value=0): with jax.checking_leaks(): for i in range(5): - rngs = jax.random.split(jax.random.PRNGKey(23), 100) + rngs = jax.random.split(jax.random.key(23), 100) out = sample_from_prior(rngs, np.ones((4, 50))) out.block_until_ready() del out, rngs @@ -2590,7 +2586,7 @@ def __call__(self, x): with set_config('flax_preserve_adopted_names', True): foo = Foo(name='foo') bar = Bar(sub=foo) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.zeros((1,)) vs = bar.init(k, x) self.assertTrue('foo' in vs['params'], 'relaxed naming failure') @@ -2599,7 +2595,7 @@ def __call__(self, x): with set_config('flax_preserve_adopted_names', False): foo = Foo(name='foo') bar = Bar(sub=foo) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.zeros((1,)) vs = bar.init(k, x) self.assertTrue('sub' in vs['params'], 'old policy naming failure') @@ -2630,7 +2626,7 @@ def __call__(self, x): with set_config('flax_preserve_adopted_names', False): foo = Foo(name='foo') bar = Bar1(sub=foo) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.zeros((1,)) vs = bar.init(k, x) self.assertTrue('foo' in vs['params'], 'adoption naming failure') @@ -2639,7 +2635,7 @@ def __call__(self, x): with set_config('flax_preserve_adopted_names', True): foo = Foo(name='foo') bar = Bar2(sub=foo) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.zeros((1,)) vs = bar.init(k, x) self.assertTrue('sub' in vs['params'], 'adoption naming failure') @@ -2671,7 +2667,7 @@ def __call__(self, x): foo = Foo(name='foo') bar = Bar(sub=foo, name='bar') baz = Baz(sub=bar) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.zeros((1,)) vs = baz.init(k, x) self.assertTrue('bar' in vs['params'], 'adoption naming failure') @@ -2697,7 +2693,7 @@ def __call__(self, x): foo1 = Foo(name='foo') foo2 = Foo(name='foo') bar = Bar(sub1=foo1, sub2=foo2) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.zeros((1,)) with self.assertRaises(errors.NameInUseError): vs = bar.init(k, x) @@ -2719,7 +2715,7 @@ def __call__(self, x): with set_config('flax_preserve_adopted_names', True): foo = Foo(name=None) bar = Bar(sub=foo) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.zeros((1,)) vs = bar.init(k, x) self.assertTrue('sub' in vs['params'], 'relaxed naming failure') @@ -2728,7 +2724,7 @@ def __call__(self, x): with set_config('flax_preserve_adopted_names', False): foo = Foo(name='foo') bar = Bar(sub=foo) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.zeros((1,)) vs = bar.init(k, x) self.assertTrue('sub' in vs['params'], 'old policy naming failure') @@ -2744,7 +2740,7 @@ def __call__(self, x): return x + p foo = Foo(name='foo') - k = random.PRNGKey(0) + k = random.key(0) x = 
jnp.zeros((1,)) vs = foo.init(k, x) @@ -2758,7 +2754,7 @@ def __call__(self, x): return x + v1.value + v2.value foo = Foo(name='foo') - k = random.PRNGKey(0) + k = random.key(0) x = jnp.zeros((1,)) vs = foo.init(k, x) @@ -2773,7 +2769,7 @@ def __call__(self, x): return x + v1.value + v2.value + v3.value foo = Foo(name='foo') - k = random.PRNGKey(0) + k = random.key(0) x = jnp.zeros((1,)) with self.assertRaises(errors.NameInUseError): vs = foo.init(k, x) @@ -2785,13 +2781,13 @@ def test_frozendict_flag(self): with set_config('flax_return_frozendict', True): x = jnp.zeros((2, 3)) layer = nn.Dense(5) - params = layer.init(random.PRNGKey(0), x) + params = layer.init(random.key(0), x) self.assertTrue(isinstance(params, FrozenDict)) with set_config('flax_return_frozendict', False): x = jnp.zeros((2, 3)) layer = nn.Dense(5) - params = layer.init(random.PRNGKey(0), x) + params = layer.init(random.key(0), x) self.assertTrue(isinstance(params, dict)) diff --git a/tests/linen/linen_recurrent_test.py b/tests/linen/linen_recurrent_test.py index 2f09a487bf..74c900128b 100644 --- a/tests/linen/linen_recurrent_test.py +++ b/tests/linen/linen_recurrent_test.py @@ -40,7 +40,7 @@ def test_rnn_basic_forward(self): rnn = nn.RNN(nn.LSTMCell(channels_out), return_carry=True) xs = jnp.ones((batch_size, seq_len, channels_in)) - variables = rnn.init(jax.random.PRNGKey(0), xs) + variables = rnn.init(jax.random.key(0), xs) ys: jnp.ndarray carry, ys = rnn.apply(variables, xs) @@ -63,7 +63,7 @@ def test_rnn_multiple_batch_dims(self): rnn = nn.RNN(nn.LSTMCell(channels_out), return_carry=True) xs = jnp.ones((*batch_dims, seq_len, channels_in)) - variables = rnn.init(jax.random.PRNGKey(0), xs) + variables = rnn.init(jax.random.key(0), xs) ys: jnp.ndarray carry, ys = rnn.apply(variables, xs) @@ -86,7 +86,7 @@ def test_rnn_unroll(self): rnn = nn.RNN(nn.LSTMCell(channels_out), unroll=10, return_carry=True) xs = jnp.ones((batch_size, seq_len, channels_in)) - variables = rnn.init(jax.random.PRNGKey(0), xs) + variables = rnn.init(jax.random.key(0), xs) ys: jnp.ndarray carry, ys = rnn.apply(variables, xs) @@ -109,7 +109,7 @@ def test_rnn_time_major(self): rnn = nn.RNN(nn.LSTMCell(channels_out), time_major=True, return_carry=True) xs = jnp.ones((seq_len, batch_size, channels_in)) - variables = rnn.init(jax.random.PRNGKey(0), xs) + variables = rnn.init(jax.random.key(0), xs) ys: jnp.ndarray carry, ys = rnn.apply(variables, xs) @@ -142,7 +142,7 @@ def test_rnn_with_spatial_dimensions(self): ) xs = jnp.ones((batch_size, seq_len, *image_size, channels_in)) - variables = rnn.init(jax.random.PRNGKey(0), xs) + variables = rnn.init(jax.random.key(0), xs) ys: jnp.ndarray carry, ys = rnn.apply(variables, xs, return_carry=True) @@ -174,11 +174,9 @@ def test_numerical_equivalence(self): xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray - (carry, ys), variables = rnn.init_with_output(jax.random.PRNGKey(0), xs) + (carry, ys), variables = rnn.init_with_output(jax.random.key(0), xs) - cell_carry = rnn.cell.initialize_carry( - jax.random.PRNGKey(0), xs[:, 0].shape - ) + cell_carry = rnn.cell.initialize_carry(jax.random.key(0), xs[:, 0].shape) cell_params = variables['params']['cell'] for i in range(seq_len): @@ -195,7 +193,7 @@ def test_numerical_equivalence_with_mask(self): channels_in = 5 channels_out = 6 - key = jax.random.PRNGKey(0) + key = jax.random.key(0) seq_lengths = jax.random.randint( key, (batch_size,), minval=1, maxval=seq_len + 1 ) @@ -205,12 +203,10 @@ def test_numerical_equivalence_with_mask(self): xs = 
jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray (carry, ys), variables = rnn.init_with_output( - jax.random.PRNGKey(0), xs, seq_lengths=seq_lengths + jax.random.key(0), xs, seq_lengths=seq_lengths ) - cell_carry = rnn.cell.initialize_carry( - jax.random.PRNGKey(0), xs[:, 0].shape - ) + cell_carry = rnn.cell.initialize_carry(jax.random.key(0), xs[:, 0].shape) cell_params = variables['params']['cell'] carries = [] @@ -238,14 +234,12 @@ def test_numerical_equivalence_single_batch(self): xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray - (carry, ys), variables = rnn.init_with_output(jax.random.PRNGKey(0), xs) + (carry, ys), variables = rnn.init_with_output(jax.random.key(0), xs) cell_params = variables['params']['cell'] for batch_idx in range(batch_size): - cell_carry = rnn.cell.initialize_carry( - jax.random.PRNGKey(0), xs[:1, 0].shape - ) + cell_carry = rnn.cell.initialize_carry(jax.random.key(0), xs[:1, 0].shape) for i in range(seq_len): cell_carry, y = rnn.cell.apply( @@ -272,16 +266,14 @@ def test_numerical_equivalence_single_batch_nn_scan(self): )(channels_out) xs = jnp.ones((batch_size, seq_len, channels_in)) - carry = rnn.initialize_carry(jax.random.PRNGKey(0), xs[:, 0].shape) + carry = rnn.initialize_carry(jax.random.key(0), xs[:, 0].shape) ys: jnp.ndarray - (carry, ys), variables = rnn.init_with_output( - jax.random.PRNGKey(0), carry, xs - ) + (carry, ys), variables = rnn.init_with_output(jax.random.key(0), carry, xs) cell_params = variables['params'] for batch_idx in range(batch_size): - cell_carry = cell.initialize_carry(jax.random.PRNGKey(0), xs[:1, 0].shape) + cell_carry = cell.initialize_carry(jax.random.key(0), xs[:1, 0].shape) for i in range(seq_len): cell_carry, y = cell.apply( @@ -301,11 +293,11 @@ def test_numerical_equivalence_single_batch_jax_scan(self): channels_out = 6 xs = jax.random.uniform( - jax.random.PRNGKey(0), (batch_size, seq_len, channels_in) + jax.random.key(0), (batch_size, seq_len, channels_in) ) cell: nn.LSTMCell = nn.LSTMCell(channels_out) - carry = cell.initialize_carry(jax.random.PRNGKey(0), xs[:, 0].shape) - variables = cell.init(jax.random.PRNGKey(0), carry, xs[:, 0]) + carry = cell.initialize_carry(jax.random.key(0), xs[:, 0].shape) + variables = cell.init(jax.random.key(0), carry, xs[:, 0]) cell_params = variables['params'] def scan_fn(carry, x): @@ -315,7 +307,7 @@ def scan_fn(carry, x): carry, ys = jax.lax.scan(scan_fn, carry, xs.swapaxes(0, 1)) ys = ys.swapaxes(0, 1) - cell_carry = cell.initialize_carry(jax.random.PRNGKey(0), xs[:, 0].shape) + cell_carry = cell.initialize_carry(jax.random.key(0), xs[:, 0].shape) for i in range(seq_len): cell_carry, y = cell.apply( @@ -335,14 +327,12 @@ def test_reverse(self): xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray - (carry, ys), variables = rnn.init_with_output(jax.random.PRNGKey(0), xs) + (carry, ys), variables = rnn.init_with_output(jax.random.key(0), xs) cell_params = variables['params']['cell'] for batch_idx in range(batch_size): - cell_carry = rnn.cell.initialize_carry( - jax.random.PRNGKey(0), xs[:1, 0].shape - ) + cell_carry = rnn.cell.initialize_carry(jax.random.key(0), xs[:1, 0].shape) for i in range(seq_len): cell_carry, y = rnn.cell.apply( @@ -373,14 +363,12 @@ def test_reverse_but_keep_order(self): xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray - (carry, ys), variables = rnn.init_with_output(jax.random.PRNGKey(0), xs) + (carry, ys), variables = rnn.init_with_output(jax.random.key(0), xs) cell_params = 
variables['params']['cell'] for batch_idx in range(batch_size): - cell_carry = rnn.cell.initialize_carry( - jax.random.PRNGKey(0), xs[:1, 0].shape - ) + cell_carry = rnn.cell.initialize_carry(jax.random.key(0), xs[:1, 0].shape) for i in range(seq_len): cell_carry, y = rnn.cell.apply( @@ -441,7 +429,7 @@ def test_flip_sequence_time_major_more_feature_dims(self): def test_basic_seq_lengths(self): x = jnp.ones((2, 10, 6)) lstm = nn.RNN(nn.LSTMCell(265)) - variables = lstm.init(jax.random.PRNGKey(0), x) + variables = lstm.init(jax.random.key(0), x) y = lstm.apply(variables, x, seq_lengths=jnp.array([5, 5])) @@ -459,7 +447,7 @@ def test_bidirectional(self): xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray - ys, variables = bdirectional.init_with_output(jax.random.PRNGKey(0), xs) + ys, variables = bdirectional.init_with_output(jax.random.key(0), xs) self.assertEqual(ys.shape, (batch_size, seq_len, channels_out * 2)) @@ -474,7 +462,7 @@ def test_shared_cell(self): xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray - ys, variables = bdirectional.init_with_output(jax.random.PRNGKey(0), xs) + ys, variables = bdirectional.init_with_output(jax.random.key(0), xs) self.assertEqual(ys.shape, (batch_size, seq_len, channels_out * 2)) @@ -492,7 +480,7 @@ def test_custom_merge_fn(self): xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray - ys, variables = bdirectional.init_with_output(jax.random.PRNGKey(0), xs) + ys, variables = bdirectional.init_with_output(jax.random.key(0), xs) self.assertEqual(ys.shape, (batch_size, seq_len, channels_out)) @@ -511,7 +499,7 @@ def test_return_carry(self): xs = jnp.ones((batch_size, seq_len, channels_in)) ys: jnp.ndarray (carry, ys), variables = bdirectional.init_with_output( - jax.random.PRNGKey(0), xs + jax.random.key(0), xs ) carry_forward, carry_backward = carry @@ -539,7 +527,7 @@ def test_constructor(self, cell_type): cell_type=[nn.LSTMCell, nn.GRUCell, nn.OptimizedLSTMCell] ) def test_initialize_carry(self, cell_type): - key = jax.random.PRNGKey(0) + key = jax.random.key(0) with self.assertRaisesRegex(TypeError, 'The RNNCellBase API has changed'): cell_type.initialize_carry(key, (2,), 3) diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py index 54a46064a9..3ccb3f593a 100644 --- a/tests/linen/linen_test.py +++ b/tests/linen/linen_test.py @@ -148,7 +148,7 @@ def test_pooling_no_batch_dims(self): class NormalizationTest(parameterized.TestCase): def test_batch_norm(self): - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) x = random.normal(key1, (4, 3, 2)) model_cls = nn.BatchNorm(momentum=0.9, use_running_average=False) @@ -170,7 +170,7 @@ def test_batch_norm(self): ) def test_batch_norm_complex(self): - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) x = random.normal(key1, (4, 3, 2), dtype=jnp.complex64) model_cls = nn.BatchNorm( @@ -202,7 +202,7 @@ def test_batch_norm_complex(self): {'reduction_axes': -1, 'use_fast_variance': False}, ) def test_layer_norm(self, reduction_axes, use_fast_variance=True): - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 3, 4)) @@ -227,7 +227,7 @@ def test_layer_norm(self, reduction_axes, use_fast_variance=True): {'reduction_axes': -1}, {'reduction_axes': 1}, {'reduction_axes': (1, 2)} ) def test_rms_norm(self, reduction_axes): - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 3, 4)) @@ 
-243,7 +243,7 @@ def test_rms_norm(self, reduction_axes): np.testing.assert_allclose(y_one_liner, y, atol=1e-4) def test_group_norm(self): - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 5, 4, 4, 32)) @@ -264,7 +264,7 @@ def test_group_norm(self): np.testing.assert_allclose(y_test, y, atol=1e-4) def test_group_norm_raises(self): - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) e = 1e-5 x = random.normal(key1, (2, 5, 4, 4, 32)) @@ -288,9 +288,9 @@ def __call__(self, x): x = norm(x) return x, norm(x) - key = random.PRNGKey(0) + key = random.key(0) model = Foo() - x = random.normal(random.PRNGKey(1), (2, 4)) + x = random.normal(random.key(1), (2, 4)) (y1, y2), variables = model.init_with_output(key, x) np.testing.assert_allclose(y1, y2, rtol=1e-4) @@ -298,7 +298,7 @@ def __call__(self, x): class StochasticTest(absltest.TestCase): def test_dropout(self): - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) module = nn.Dropout(rate=0.5) y1 = module.apply( @@ -318,7 +318,7 @@ def test_dropout(self): self.assertTrue(np.all(y1 == y2)) def test_dropout_rate_stats(self): - rootkey = random.PRNGKey(0) + rootkey = random.key(0) for rate in np.arange(0.1, 1.0, 0.1): rootkey, subkey = random.split(rootkey) module = nn.Dropout(rate=rate) @@ -337,7 +337,7 @@ def test_dropout_rate_stats(self): self.assertTrue(keep_rate - delta < frac < keep_rate + delta) def test_dropout_rate_limits(self): - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2, key3 = random.split(rng, 3) inputs = jnp.ones((20, 20)) d0 = nn.Dropout(rate=0.0) @@ -363,7 +363,7 @@ def __call__(self, x): module = Foo() x1, x2 = module.apply( - {}, jnp.ones((20, 20)), rngs={'dropout': random.PRNGKey(0)} + {}, jnp.ones((20, 20)), rngs={'dropout': random.key(0)} ) np.testing.assert_array_equal(x1, x2) @@ -374,7 +374,7 @@ class RecurrentTest(absltest.TestCase): def test_lstm(self): lstm = nn.LSTMCell(features=4) - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) x = random.normal(key1, (2, 3)) c0, h0 = lstm.initialize_carry(rng, x.shape) @@ -401,7 +401,7 @@ def test_lstm(self): def test_gru(self): gru = nn.GRUCell(features=4) - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) x = random.normal(key1, (2, 3)) carry0 = gru.initialize_carry(rng, x.shape) @@ -424,7 +424,7 @@ def test_gru(self): def test_complex_input_gru(self): gru = nn.GRUCell(features=4) - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) x = random.normal(key1, (2, 3), dtype=jnp.complex64) carry0 = gru.initialize_carry(rng, x.shape) @@ -435,7 +435,7 @@ def test_complex_input_gru(self): def test_convlstm(self): lstm = nn.ConvLSTMCell(features=6, kernel_size=(3, 3)) - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) x = random.normal(key1, (2, 4, 4, 3)) c0, h0 = lstm.initialize_carry(rng, x.shape) @@ -457,7 +457,7 @@ def test_convlstm(self): def test_optimized_lstm_cell_matches_regular(self): # Create regular LSTMCell. lstm = nn.LSTMCell(features=4) - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) x = random.normal(key1, (2, 3)) c0, h0 = lstm.initialize_carry(rng, x.shape) @@ -467,7 +467,7 @@ def test_optimized_lstm_cell_matches_regular(self): # Create OptimizedLSTMCell. 
lstm_opt = nn.OptimizedLSTMCell(features=4) - rng = random.PRNGKey(0) + rng = random.key(0) key1, key2 = random.split(rng) x = random.normal(key1, (2, 3)) c0, h0 = lstm_opt.initialize_carry(rng, x.shape) diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index 079bb00491..6ef18b140b 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -85,7 +85,7 @@ def __call__(self, inputs): class TransformTest(absltest.TestCase): def test_jit(self): - key1, key2 = random.split(random.PRNGKey(3), 2) + key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4, 4)) normal_model = TransformedMLP(features=[3, 4, 5]) @@ -97,7 +97,7 @@ def test_jit(self): self.assertTrue(np.all(y1 == y2)) def test_jit_decorated(self): - key1, key2 = random.split(random.PRNGKey(3), 2) + key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4, 4)) normal_model = decorated_MLP()(features=[3, 4, 5]) @@ -109,7 +109,7 @@ def test_jit_decorated(self): self.assertTrue(np.all(y1 == y2)) def test_remat(self): - key1, key2 = random.split(random.PRNGKey(3), 2) + key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4, 4)) normal_model = TransformedMLP(features=[3, 4, 5]) @@ -121,7 +121,7 @@ def test_remat(self): self.assertTrue(np.all(y1 == y2)) def test_remat_decorated(self): - key1, key2 = random.split(random.PRNGKey(3), 2) + key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4, 4)) normal_model = decorated_MLP()(features=[3, 4, 5]) @@ -141,7 +141,7 @@ class ConditionalReLU(nn.Module): def __call__(self, input, apply_relu: bool = False): return nn.relu(input) if apply_relu else input - key = random.PRNGKey(0) + key = random.key(0) x = jnp.ones((4, 4)) * -1 remat_model = nn.remat(ConditionalReLU)() p = remat_model.init(key, x) @@ -172,7 +172,7 @@ def __call__(self, inputs, train: bool): foo = FooRemat(train_is_static=True) x = jnp.empty((1, 2)) - variables = foo.init(random.PRNGKey(0), x, True) + variables = foo.init(random.key(0), x, True) y = foo.apply(variables, x, False) self.assertEqual(y.shape, (1, 3)) @@ -180,7 +180,7 @@ def __call__(self, inputs, train: bool): FooRemat = nn.remat(Foo, static_argnums=()) foo = FooRemat(train_is_static=False) - variables = foo.init(random.PRNGKey(0), x, True) + variables = foo.init(random.key(0), x, True) y = foo.apply(variables, x, False) self.assertEqual(y.shape, (1, 3)) @@ -200,7 +200,7 @@ def __call__(self, inputs, train: bool): foo = FooTrainStatic() x = jnp.empty((1, 2)) - variables = foo.init(random.PRNGKey(0), x, True) + variables = foo.init(random.key(0), x, True) y = foo.apply(variables, x, False) self.assertEqual(y.shape, (1, 3)) @@ -216,12 +216,12 @@ def __call__(self, inputs, train: bool): # set train as a non-static arguments foo = FooTrainDynamic() - variables = foo.init(random.PRNGKey(0), x, True) + variables = foo.init(random.key(0), x, True) y = foo.apply(variables, x, False) self.assertEqual(y.shape, (1, 3)) def test_vmap(self): - key1, key2 = random.split(random.PRNGKey(3), 2) + key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4, 4)) x2 = random.uniform(key1, (5, 4, 4)) @@ -247,7 +247,7 @@ def vmap(cls): np.testing.assert_allclose(y1, y2, atol=1e-7) def test_vmap_decorated(self): - key1, key2 = random.split(random.PRNGKey(3), 2) + key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4, 4)) x2 = random.uniform(key1, (5, 4, 4)) @@ -273,7 +273,7 @@ def vmap(fn): np.testing.assert_allclose(y1, y2, 
atol=1e-7) def test_vmap_batchnorm(self): - key1, key2 = random.split(random.PRNGKey(3), 2) + key1, key2 = random.split(random.key(3), 2) x = random.uniform(key1, (4, 4)) x2 = random.uniform(key1, (5, 4, 4)) @@ -318,9 +318,9 @@ def __call__(self, c, xs): ) return LSTM(self.features, name='lstm_cell')(c, xs) - key1, key2 = random.split(random.PRNGKey(0), 2) + key1, key2 = random.split(random.key(0), 2) xs = random.uniform(key1, (5, 3, 2)) - dummy_rng = random.PRNGKey(0) + dummy_rng = random.key(0) init_carry = nn.LSTMCell(2).initialize_carry(dummy_rng, xs[0].shape) model = SimpleScan(2) init_variables = model.init(key2, init_carry, xs) @@ -355,10 +355,10 @@ def __call__(self, c, b, xs): assert b.shape == (4,) return nn.LSTMCell(self.features, name='lstm_cell')(c, xs) - key1, key2 = random.split(random.PRNGKey(0), 2) + key1, key2 = random.split(random.key(0), 2) xs = random.uniform(key1, (4, 3, 2)) b = jnp.ones((4,)) - dummy_rng = random.PRNGKey(0) + dummy_rng = random.key(0) init_carry = nn.LSTMCell(2).initialize_carry(dummy_rng, xs[0].shape) model = SimpleScan(2) init_variables = model.init(key2, init_carry, b, xs) @@ -413,7 +413,7 @@ def __call__(self, x): return res x = jnp.ones((10, 10)) - rngs = random.PRNGKey(0) + rngs = random.key(0) init_vars = Test(parent=None).init(rngs, x) _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( @@ -461,7 +461,7 @@ def __call__(self, x): return res x = jnp.ones((1, 1)) - rngs = random.PRNGKey(0) + rngs = random.key(0) init_vars = Test(parent=None).init(rngs, x) _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( @@ -507,7 +507,7 @@ def __call__(self, x): return res x = jnp.ones((1, 1)) - rngs = random.PRNGKey(0) + rngs = random.key(0) init_vars = Test(parent=None).init(rngs, x) _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( @@ -562,7 +562,7 @@ def __call__(self, x): return res x = jnp.ones((1, 1)) - rngs = random.PRNGKey(0) + rngs = random.key(0) init_vars = Test(parent=None).init(rngs, x) _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( @@ -618,7 +618,7 @@ def __call__(self, x): return res x = jnp.ones((1, 1)) - rngs = random.PRNGKey(0) + rngs = random.key(0) init_vars = Test(parent=None).init(rngs, x) _, new_vars = Test(parent=None).apply(init_vars, x, mutable=['counter']) self.assertEqual( @@ -661,7 +661,7 @@ def __call__(self, x): return res x = jnp.ones((3, 1, 2)) - rngs = random.PRNGKey(0) + rngs = random.key(0) init_vars = Test(parent=None).init(rngs, x) y = Test(parent=None).apply(init_vars, x) self.assertEqual( @@ -688,7 +688,7 @@ def __call__(self, x): variable_axes={'params': 0}, split_rngs={'params': True}, ) - variables = FooVmap().init(random.PRNGKey(0), jnp.ones((4,))) + variables = FooVmap().init(random.key(0), jnp.ones((4,))) self.assertEqual(variables['params']['test'].shape, (4,)) def test_nested_module_args_vmap(self): @@ -724,7 +724,7 @@ def __call__(self, x): c = C(b) return c(x) - key = random.PRNGKey(0) + key = random.key(0) x = jnp.ones((10, 10)) p = D().init(key, x) @@ -771,7 +771,7 @@ def __call__(self, x): c = C(a2, b) return c(x) - key = random.PRNGKey(0) + key = random.key(0) x = jnp.ones((10, 10)) p = D().init(key, x) @@ -825,7 +825,7 @@ def nested_repeat(mdl): mdl = partial(Repeat, mdl) return mdl() - _ = nested_repeat(Counter).init(random.PRNGKey(0), jnp.ones((2,))) + _ = nested_repeat(Counter).init(random.key(0), jnp.ones((2,))) # setup_cntr == 128 due to 1 
call in Counter.setup by _validate_setup # and 1 further "real" call. self.assertEqual(setup_cntr, 128) @@ -859,7 +859,7 @@ def __call__(self, x): y2 = self.a.bar(x) return y1, y2 - key = random.PRNGKey(0) + key = random.key(0) x = jnp.ones((2,)) (y1, y2), _ = B().init_with_output(key, x) np.testing.assert_array_equal(y1, y2) @@ -912,7 +912,7 @@ def __call__(self, x): c = C(a2, b) return c(x) - key = random.PRNGKey(0) + key = random.key(0) x = jnp.ones((10, 10)) p1 = D().init(key, x) y1 = D().apply(p1, x) @@ -969,7 +969,7 @@ def __call__(self, c, x): split_rngs={'params': False}, )(As) - key = random.PRNGKey(0) + key = random.key(0) x = jnp.ones((10, 2)) p = B(As).init(key, x, x) @@ -995,7 +995,7 @@ def __call__(self, c, x): ) def test_partially_applied_module_constructor_transform(self): - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((3, 4, 4)) dense = partial(nn.Dense, use_bias=False) vmap_dense = nn.vmap( @@ -1011,7 +1011,7 @@ def test_partially_applied_module_constructor_transform(self): self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes)) def test_partial_module_method(self): - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((3, 4, 4)) class Foo(nn.Module): @@ -1049,7 +1049,7 @@ def mutate_variable_in_method(self, x, baz): baz.value += x return baz.value - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((1,)) variables = Test().init(k, x) np.testing.assert_allclose( @@ -1091,7 +1091,7 @@ def __call__(self, x): def call_instance_arg_in_method(self, x, inner): return inner(x) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((1,)) variables = Test().init(k, x) np.testing.assert_allclose( @@ -1141,7 +1141,7 @@ def __call__(self, x): outer = Outer(name='outer') return outer(inner, x) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((1,)) variables = Test().init(k, x) np.testing.assert_allclose( @@ -1187,7 +1187,7 @@ def __call__(self, x): y = VarUser(baz)(x) return y - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((1,)) variables = VarPasser().init(k, x) np.testing.assert_allclose( @@ -1292,7 +1292,7 @@ def __call__(self, x): x = jnp.ones((2, 4)) ae = TiedAutencoder(4, 5) - variables = ae.init(random.PRNGKey(0), x) + variables = ae.init(random.key(0), x) param_shapes = jax.tree_util.tree_map(jnp.shape, variables['params']) self.assertEqual(param_shapes, {'Dense_0': {'kernel': (4, 5)}}) @@ -1309,7 +1309,7 @@ def sign(x): bw = BitWeights() x = jnp.ones((2, 4)) - y, variables = bw.init_with_output(random.PRNGKey(0), x) + y, variables = bw.init_with_output(random.key(0), x) y_2 = bw.apply(variables, x) np.testing.assert_allclose(y, y_2) @@ -1323,7 +1323,7 @@ def __call__(self, x): x = jnp.ones((2, 8)) model = BigModel() - variables = model.init(random.PRNGKey(0), x) + variables = model.init(random.key(0), x) param_shapes = jax.tree_util.tree_map(jnp.shape, variables['params']) self.assertEqual(param_shapes['dense_stack']['kernel'], (100, 8, 8)) self.assertEqual(param_shapes['dense_stack']['bias'], (100, 8)) @@ -1348,7 +1348,7 @@ def __call__(self, x, y): x = jnp.array([1.0, 2.0, 3.0]) y = jnp.array([4.0, 5.0, 6.0]) - params = Foo().init(random.PRNGKey(0), x, y) + params = Foo().init(random.key(0), x, y) params_grad, x_grad, y_grad = Foo().apply(params, x, y) self.assertEqual( params_grad, @@ -1382,7 +1382,7 @@ def __call__(self, x): return out_t x = jnp.ones((3,)) - params = Foo().init(random.PRNGKey(0), x) + params = Foo().init(random.key(0), x) y_t = Foo().apply(params, x) np.testing.assert_allclose(y_t, jnp.ones_like(x)) @@ -1417,7 
+1417,7 @@ def __call__(self, x): return x a = A(b=B(c=C())) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((1,), jnp.float32) vs = a.init(k, x) y, vs_new = a.apply( @@ -1454,7 +1454,7 @@ def bwd(vjp_fn, y_t): return sign_grad(nn.Dense(1), x).reshape(()) x = jnp.ones((2,)) - variables = Foo().init(random.PRNGKey(0), x) + variables = Foo().init(random.key(0), x) grad = jax.grad(Foo().apply)(variables, x) for grad_leaf in jax.tree_util.tree_leaves(grad): self.assertTrue(jnp.all(jnp.abs(grad_leaf) == 1.0)) @@ -1474,7 +1474,7 @@ def helper(self, x, m): def __call__(self, x): return self.helper(x, self.inner) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((2,)) vs_foo = Foo().init(k, x) @@ -1513,7 +1513,7 @@ def __call__(self, x): a = nn.Dense(2, name='a') return self.helper(x, a) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((2,)) with self.assertRaises(errors.NameInUseError): vs = Foo().init(k, x) @@ -1535,7 +1535,7 @@ def setup(self): def __call__(self, x): return x - k = random.PRNGKey(0) + k = random.key(0) x = jnp.array([1.0]) with self.assertRaises(errors.NameInUseError): @@ -1566,7 +1566,7 @@ def helper(self, x, ms): def __call__(self, x): return self.helper(x, self.inners) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((2,)) vs_0 = Foo().init(k, x) @@ -1591,7 +1591,7 @@ def setup(self): def __call__(self, x): return x - k = random.PRNGKey(0) + k = random.key(0) x = jnp.array([1.0]) msg = r'Could not create submodule "subs_0".*' @@ -1611,7 +1611,7 @@ def __call__(self, x): scanbody = nn.scan( Body, variable_axes={'params': 0}, split_rngs={'params': True}, length=2 ) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((1,)) vs = scanbody().init(k, x) y = scanbody().apply(vs, x) @@ -1656,8 +1656,8 @@ def __call__(self, x): z, _ = sf.method_1(y, ys) return z - k = random.PRNGKey(0) - x = random.uniform(random.PRNGKey(1), (2, 2)) + k = random.key(0) + x = random.uniform(random.key(1), (2, 2)) vs = Bar().init(k, x) y = Bar().apply(vs, x) @@ -1679,7 +1679,7 @@ def __call__(self, x): x = nn.jit(Foo)(dense, dense)(x) return x - k = random.PRNGKey(0) + k = random.key(0) x = jnp.zeros((2, 2)) _ = Bar().init(k, x) @@ -1700,7 +1700,7 @@ def __call__(self, x): x = nn.jit(Foo)(dense)(x, dense) return x - k = random.PRNGKey(0) + k = random.key(0) x = jnp.zeros((2, 2)) _ = Bar().init(k, x) @@ -1730,7 +1730,7 @@ def setup_helper(self): def __call__(self, x): return self.b(self.a(x)) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((2, 2)) vs = JitFoo().init(k, x) y0 = JitFoo().apply(vs, x) @@ -1743,7 +1743,7 @@ class Foo(nn.Module): @nn.compact def __call__(self, x): - key_zero = random.PRNGKey(0) + key_zero = random.key(0) key_zero = jnp.broadcast_to(key_zero, (2, *key_zero.shape)) self.param('inc', lambda _: 1) self.put_variable('state', 'acc', 0) @@ -1787,7 +1787,7 @@ def body_fn(mdl, c): {}, x, mutable=True, - rngs={'params': random.PRNGKey(1), 'loop': random.PRNGKey(2)}, + rngs={'params': random.key(1), 'loop': random.key(2)}, ) self.assertEqual(vars['state']['acc'], x) np.testing.assert_array_equal( @@ -1842,7 +1842,7 @@ def c_fn(mdl, x): x = jnp.ones((1, 3)) foo = Foo() - y1, vars = foo.init_with_output(random.PRNGKey(0), x, 0) + y1, vars = foo.init_with_output(random.key(0), x, 0) self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 0, 'c_count': 0}) y2, updates = foo.apply(vars, x, 1, mutable='state') vars = copy(vars, updates) @@ -1882,7 +1882,7 @@ def fn(mdl, x): x = jnp.ones((1, 3)) foo = Foo() - y1, vars = 
foo.init_with_output(random.PRNGKey(0), x, 0) + y1, vars = foo.init_with_output(random.key(0), x, 0) self.assertEqual(vars['state'], {'0_count': 1, '1_count': 0, '2_count': 0}) y2, updates = foo.apply(vars, x, 1, mutable='state') vars = copy(vars, updates) @@ -1924,7 +1924,7 @@ def __call__(self, x): return nn.checkpoint(nn.Dense(2))(x) with self.assertRaises(errors.TransformTargetError): - Foo().init(random.PRNGKey(0), jnp.zeros((2, 3))) + Foo().init(random.key(0), jnp.zeros((2, 3))) def test_scan_compact_count(self): class Foo(nn.Module): @@ -1945,7 +1945,7 @@ def body_fn(mdl, x): m = Foo() x = jnp.ones((3,)) - v = m.init(jax.random.PRNGKey(0), x) + v = m.init(jax.random.key(0), x) self.assertEqual(v['params']['Dense_0']['kernel'].shape, (5, 3, 3)) m.apply(v, x) @@ -1969,7 +1969,7 @@ def __call__(self, x): cond_model = CondModel() output, init_params = jax.jit(cond_model.init_with_output)( - jax.random.PRNGKey(0), x=jnp.ones(3) + jax.random.key(0), x=jnp.ones(3) ) def test_add_metadata_axis(self): @@ -1999,7 +1999,7 @@ class Test(nn.Module): def __call__(self, x): return Foo(name='foo')(x) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((4, 4), dtype=jnp.float32) vs = Test().init(k, x) y = Test().apply(vs, x) @@ -2063,7 +2063,7 @@ def __call__(self, x): a = self.a(x) return a + b - k = random.PRNGKey(0) + k = random.key(0) x = random.randint(k, (2, 2), minval=0, maxval=10) vs = C().init(k, x) y = C().apply(vs, x) diff --git a/tests/linen/partitioning_test.py b/tests/linen/partitioning_test.py index 7b79238b82..ddade03e88 100644 --- a/tests/linen/partitioning_test.py +++ b/tests/linen/partitioning_test.py @@ -201,7 +201,7 @@ def __call__(self, x): ) return x + foo - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((2, 2)) _ = ParamTest().init(k, x) @@ -220,7 +220,7 @@ def __call__(self, x): return x + foo p_rules = (('foo', 'model'), ('bar', 'data'), ('baz', None)) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((2, 2)) with partitioning.axis_rules(p_rules): variables = ParamTest().init(k, x) @@ -252,7 +252,7 @@ def __call__(self, x): return x + foo['a'] p_rules = (('foo', 'model'), ('bar', 'data'), ('baz', None)) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((2, 2)) with partitioning.axis_rules(p_rules): variables = ParamTest().init(k, x) @@ -287,7 +287,7 @@ def __call__(self, x): ) return x + foo.value - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((2, 2)) _ = VarTest().init(k, x) @@ -301,7 +301,7 @@ def __call__(self, x): ) return x + foo.value - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((2, 2)) variables = VarTest().init(k, x) logical_axis_names = partitioning.get_axis_names(variables['test_axes']) @@ -318,7 +318,7 @@ def __call__(self, x): return x + foo.value p_rules = (('foo', 'model'), ('bar', 'data'), ('baz', None)) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((2, 2)) with partitioning.axis_rules(p_rules): variables = VarTest().init(k, x) @@ -355,7 +355,7 @@ def __call__(self, x): ('bar', 'data'), ('baz', None), ) - k = random.PRNGKey(0) + k = random.key(0) x = jnp.ones((2, 2)) with partitioning.axis_rules(p_rules): variables = VarTest().init(k, x) @@ -377,7 +377,7 @@ def test_scan_with_axes(self): B, L, E = 8, 4, 32 # pylint: disable=invalid-name # fake inputs x = jnp.ones((B, E)) - k = random.PRNGKey(0) + k = random.key(0) class SinDot(nn.Module): depth: int @@ -487,7 +487,7 @@ def __call__(self, x): # check that regular Food module is correct with partitioning.axis_rules(p_rules): - variables = 
Foo().init(jax.random.PRNGKey(0), jnp.array([1, 2, 3])) + variables = Foo().init(jax.random.key(0), jnp.array([1, 2, 3])) variables = unfreeze(variables) variables['params'] = jax.tree_util.tree_map( lambda x: x.shape, variables['params'] @@ -505,7 +505,7 @@ def __call__(self, x): # check that FooVmapped adds 'vmap_axis' to axis 1 with partitioning.axis_rules(p_rules): variables = Vmapped().init( - jax.random.PRNGKey(0), jnp.array([[1, 2, 3], [4, 5, 6]]) + jax.random.key(0), jnp.array([[1, 2, 3], [4, 5, 6]]) ) variables = unfreeze(variables) variables['params'] = jax.tree_util.tree_map( @@ -547,7 +547,7 @@ def __call__(self, x): @jax.jit def create_state(): module = Foo() - variables = module.init(random.PRNGKey(0), jnp.zeros((8, 4))) + variables = module.init(random.key(0), jnp.zeros((8, 4))) logical_spec = nn.get_partition_spec(variables) shardings = nn.logical_to_mesh_sharding(logical_spec, mesh, rules) variables = jax.lax.with_sharding_constraint(variables, shardings) diff --git a/tests/linen/summary_test.py b/tests/linen/summary_test.py index eb2f5624b7..487cca204e 100644 --- a/tests/linen/summary_test.py +++ b/tests/linen/summary_test.py @@ -124,7 +124,7 @@ def test_module_summary(self): module = CNN(test_sow=False) table = summary._get_module_table(module, depth=None, show_repeated=True)( - {"dropout": random.PRNGKey(0), "params": random.PRNGKey(1)}, + {"dropout": random.key(0), "params": random.key(1)}, x, training=True, mutable=True, @@ -219,7 +219,7 @@ def test_module_summary_with_depth(self): module = CNN(test_sow=False) table = summary._get_module_table(module, depth=1, show_repeated=True)( - {"dropout": random.PRNGKey(0), "params": random.PRNGKey(1)}, + {"dropout": random.key(0), "params": random.key(1)}, x, training=True, mutable=True, @@ -284,7 +284,7 @@ def test_tabulate(self): module = CNN(test_sow=False) module_repr = module.tabulate( - {"dropout": random.PRNGKey(0), "params": random.PRNGKey(1)}, + {"dropout": random.key(0), "params": random.key(1)}, x, training=True, console_kwargs=CONSOLE_TEST_KWARGS, @@ -326,7 +326,7 @@ def test_tabulate_with_sow(self): module = CNN(test_sow=True) module_repr = module.tabulate( - {"dropout": random.PRNGKey(0), "params": random.PRNGKey(1)}, + {"dropout": random.key(0), "params": random.key(1)}, x, training=True, console_kwargs=CONSOLE_TEST_KWARGS, @@ -342,7 +342,7 @@ def test_tabulate_with_method(self): module = CNN(test_sow=False) module_repr = module.tabulate( - {"dropout": random.PRNGKey(0), "params": random.PRNGKey(1)}, + {"dropout": random.key(0), "params": random.key(1)}, x, training=True, method=CNN.cnn_method, @@ -364,7 +364,7 @@ def test_tabulate_function(self): module_repr = nn.tabulate( module, - {"dropout": random.PRNGKey(0), "params": random.PRNGKey(1)}, + {"dropout": random.key(0), "params": random.key(1)}, console_kwargs=CONSOLE_TEST_KWARGS, )( x, @@ -404,7 +404,7 @@ class LSTM(nn.Module): @nn.compact def __call__(self, x): carry = nn.LSTMCell(self.features).initialize_carry( - random.PRNGKey(0), x[:, 0].shape + random.key(0), x[:, 0].shape ) ScanLSTM = nn.scan( nn.LSTMCell, @@ -419,7 +419,7 @@ def __call__(self, x): with jax.check_tracer_leaks(True): module_repr = lstm.tabulate( - random.PRNGKey(0), + random.key(0), x=jnp.ones((32, 128, 64)), console_kwargs=CONSOLE_TEST_KWARGS, ) @@ -439,7 +439,7 @@ class LSTM(nn.Module): @nn.compact def __call__(self, x): carry = nn.LSTMCell(self.features).initialize_carry( - random.PRNGKey(0), x[:, 0].shape + random.key(0), x[:, 0].shape ) ScanLSTM = nn.scan( nn.LSTMCell, @@ -454,7 
+454,7 @@ def __call__(self, x): with jax.check_tracer_leaks(True): module_repr = lstm.tabulate( - random.PRNGKey(0), + random.key(0), x=jnp.ones((32, 128, 64)), console_kwargs=CONSOLE_TEST_KWARGS, ) @@ -490,7 +490,7 @@ def __call__(self, x): x = jnp.ones((4, 28, 28, 32)) module_repr = CNN().tabulate( - jax.random.PRNGKey(0), + jax.random.key(0), x=x, show_repeated=True, console_kwargs=CONSOLE_TEST_KWARGS, @@ -575,7 +575,7 @@ def __call__(self, x): module = Classifier() lines = module.tabulate( - jax.random.PRNGKey(0), + jax.random.key(0), jnp.empty((1, 28, 28, 1)), console_kwargs=CONSOLE_TEST_KWARGS, ).splitlines() @@ -608,7 +608,7 @@ def __call__(self, x): x = jnp.ones((16, 9)) rep = Foo().tabulate( - jax.random.PRNGKey(0), x, console_kwargs=CONSOLE_TEST_KWARGS + jax.random.key(0), x, console_kwargs=CONSOLE_TEST_KWARGS ) lines = rep.splitlines() self.assertIn("Total Parameters: 50", lines[-2]) diff --git a/tests/linen/toplevel_test.py b/tests/linen/toplevel_test.py index 992d5c0c49..caf6c29038 100644 --- a/tests/linen/toplevel_test.py +++ b/tests/linen/toplevel_test.py @@ -48,11 +48,11 @@ class ModuleTopLevelTest(absltest.TestCase): # d = Dummy(parent=None).initialized() # def test_toplevel_initialized_with_rng(self): - # d = Dummy(parent=None).initialized(rngs={'params': random.PRNGKey(0)}) + # d = Dummy(parent=None).initialized(rngs={'params': random.key(0)}) # self.assertEqual(d.variables.param.foo, 1) # def test_toplevel_initialized_frozen(self): - # d = Dummy(parent=None).initialized(rngs={'params': random.PRNGKey(0)}) + # d = Dummy(parent=None).initialized(rngs={'params': random.key(0)}) # with self.assertRaisesRegex(BaseException, "Can't set value"): # d.variables.param.foo = 2 @@ -60,15 +60,15 @@ class ModuleTopLevelTest(absltest.TestCase): # d = Dummy(parent=None) # # initializing should make a copy and not have any effect # # on `d` itself. - # d_initialized = d.initialized(rngs={'params': random.PRNGKey(0)}) + # d_initialized = d.initialized(rngs={'params': random.key(0)}) # # ... make sure that indeed `d` has no scope. 
# self.assertIsNone(d.scope) # def test_can_only_call_initialized_once(self): # d = Dummy(parent=None) - # d = d.initialized(rngs={'params': random.PRNGKey(0)}) + # d = d.initialized(rngs={'params': random.key(0)}) # with self.assertRaises(BaseException): - # d.initialized(rngs={'params': random.PRNGKey(0)}) + # d.initialized(rngs={'params': random.key(0)}) if __name__ == '__main__': diff --git a/tests/serialization_test.py b/tests/serialization_test.py index 02112bff36..521fbb1765 100644 --- a/tests/serialization_test.py +++ b/tests/serialization_test.py @@ -122,7 +122,7 @@ def test_pass_through_serialization(self): self.assertEqual(restored_box, expected_box) def test_model_serialization(self): - rng = random.PRNGKey(0) + rng = random.key(0) module = nn.Dense(features=1, kernel_init=nn.initializers.ones_init()) x = jnp.ones((1, 1), jnp.float32) initial_params = module.init(rng, x) @@ -153,7 +153,7 @@ def test_partial_serialization(self): self.assertEqual(add_one.args, restored_add_one.args) def test_optimizer_serialization(self): - rng = random.PRNGKey(0) + rng = random.key(0) module = nn.Dense(features=1, kernel_init=nn.initializers.ones_init()) x = jnp.ones((1, 1), jnp.float32) initial_params = module.init(rng, x) @@ -194,7 +194,7 @@ def __call__(self): state = self.variable('state', 'dummy', DummyDataClass.initializer, ()) state.value = state.value.replace(x=state.value.x + 1.0) - initial_variables = StatefulModule().init(random.PRNGKey(0)) + initial_variables = StatefulModule().init(random.key(0)) _, variables = StatefulModule().apply(initial_variables, mutable=['state']) serialized_state_dict = serialization.to_state_dict(variables) self.assertEqual(serialized_state_dict, {'state': {'dummy': {'x': 2.0}}}) @@ -401,7 +401,7 @@ def test_namedtuple_restore_legacy(self): self.assertEqual(x1, restored_x1) def test_model_serialization_to_bytes(self): - rng = random.PRNGKey(0) + rng = random.key(0) module = nn.Dense(features=1, kernel_init=nn.initializers.ones_init()) initial_params = module.init(rng, jnp.ones((1, 1), jnp.float32)) serialized_bytes = serialization.to_bytes(initial_params) @@ -409,7 +409,7 @@ def test_model_serialization_to_bytes(self): self.assertEqual(restored_params, initial_params) def test_optimizer_serialization_to_bytes(self): - rng = random.PRNGKey(0) + rng = random.key(0) module = nn.Dense(features=1, kernel_init=nn.initializers.ones_init()) initial_params = module.init(rng, jnp.ones((1, 1), jnp.float32)) # model = nn.Model(module, initial_params) @@ -524,7 +524,7 @@ def test_serialization_chunking3(self): def test_serialization_errors(self, target, wrong_target, msg): if target == 'original_params': x = jnp.ones((1, 28, 28, 1)) - rng = jax.random.PRNGKey(1) + rng = jax.random.key(1) original_module = OriginalModule() target = original_module.init(rng, x) wrong_module = WrongModule() @@ -532,7 +532,7 @@ def test_serialization_errors(self, target, wrong_target, msg): elif target == 'original_train_state': x = jnp.ones((1, 28, 28, 1)) - rng = jax.random.PRNGKey(1) + rng = jax.random.key(1) original_module = OriginalModule() original_params = original_module.init(rng, x) wrong_module = WrongModule() diff --git a/tests/traceback_util_test.py b/tests/traceback_util_test.py index 3a322fb8f2..ef6e327b7f 100644 --- a/tests/traceback_util_test.py +++ b/tests/traceback_util_test.py @@ -68,7 +68,7 @@ def __call__(self, x): traceback_util.hide_flax_in_tracebacks() jax.config.update('jax_traceback_filtering', 'tracebackhide') - key = random.PRNGKey(0) + key = random.key(0) try: 
nn.jit(Test1)().init(key, jnp.ones((5, 3))) except ValueError as e: @@ -105,7 +105,7 @@ def __call__(self, x): traceback_util.hide_flax_in_tracebacks() jax.config.update('jax_traceback_filtering', 'remove_frames') - key = random.PRNGKey(0) + key = random.key(0) try: nn.jit(Test1)().init(key, jnp.ones((5, 3))) except ValueError as e: @@ -145,7 +145,7 @@ def __call__(self, x): raise ValueError('error here.') return x # pylint: disable=unreachable - key = random.PRNGKey(0) + key = random.key(0) traceback_util.show_flax_in_tracebacks() jax.config.update('jax_traceback_filtering', 'off')
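
For reference, a minimal sketch of the key-style change applied throughout the hunks above. This is not part of the patch itself; it assumes jax >= 0.4.16 and flax installed, and `Model` is a hypothetical stand-in for the test modules touched here.

import jax
import jax.numpy as jnp
import flax.linen as nn


class Model(nn.Module):  # hypothetical stand-in for the modules in the tests above
  @nn.compact
  def __call__(self, x):
    return nn.Dense(3)(x)


x = jnp.ones((1, 4))

# Old style: a raw uint32 key array of shape (2,).
old_key = jax.random.PRNGKey(0)
# New style, as used throughout this patch: a typed key scalar of shape ().
new_key = jax.random.key(0)

# Both forms still initialize a module; with the default threefry
# implementation they wrap the same key data, so the parameters match.
params_old = Model().init(old_key, x)
params_new = Model().init(new_key, x)

# Typed keys can be detected via their dtype, which raw key arrays lack.
assert jnp.issubdtype(new_key.dtype, jax.dtypes.prng_key)
assert not jnp.issubdtype(old_key.dtype, jax.dtypes.prng_key)

If the raw uint32 representation is ever needed (e.g. for serialization), recent jax versions expose jax.random.key_data to recover it from a typed key.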