Skip to content

Commit

Permalink
fix: Change to new-style-jax-rng-keys. (#195)
Browse files Browse the repository at this point in the history
* fix: Change to new-style-jax-rng-keys.

---------

Co-authored-by: Milton Montero <lleramontero@gmail.com>
  • Loading branch information
Lookatator and miltonllera authored Sep 21, 2024
1 parent 5ee7823 commit 6656f5e
Show file tree
Hide file tree
Showing 65 changed files with 124 additions and 96 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ min_bd = 0.0
max_bd = 1.0

# Init a random key
random_key = jax.random.PRNGKey(seed)
random_key = jax.random.key(seed)

# Init population of controllers
random_key, subkey = jax.random.split(random_key)
Expand Down
2 changes: 1 addition & 1 deletion examples/aurora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/cmaes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@
"outputs": [],
"source": [
"state = cmaes.init()\n",
"random_key = jax.random.PRNGKey(0)"
"random_key = jax.random.key(0)"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion examples/cmame.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@
"metadata": {},
"outputs": [],
"source": [
"random_key = jax.random.PRNGKey(0)\n",
"random_key = jax.random.key(0)\n",
"# in CMA-ME settings (from the paper), there is no init population\n",
"# we multipy by zero to reproduce this setting\n",
"initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/cmamega.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@
"metadata": {},
"outputs": [],
"source": [
"random_key = jax.random.PRNGKey(0)\n",
"random_key = jax.key(0)\n",
"# no initial population - give all the same value as emitter init value\n",
"initial_population = jax.random.uniform(random_key, shape=(batch_size, num_dimensions)) * 0.\n",
"\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/dads.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@
" eval_metrics=True,\n",
")\n",
"\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"env_state = jax.jit(env.reset)(rng=key)\n",
"eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n",
"\n",
Expand Down Expand Up @@ -499,7 +499,7 @@
"outputs": [],
"source": [
"rollout = []\n",
"random_key = jax.random.PRNGKey(seed=1)\n",
"random_key = jax.random.key(seed=1)\n",
"state = jit_env_reset(rng=random_key)\n",
"while not state.done:\n",
" rollout.append(state)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/dcrlme.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
"source": [
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init environment\n",
"env = environments.create(env_name, episode_length=episode_length)\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/diayn.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
" eval_metrics=True,\n",
")\n",
"\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"env_state = jax.jit(env.reset)(rng=key)\n",
"eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n",
"\n",
Expand Down Expand Up @@ -490,7 +490,7 @@
"outputs": [],
"source": [
"rollout = []\n",
"random_key = jax.random.PRNGKey(seed=1)\n",
"random_key = jax.random.key(seed=1)\n",
"state = jit_env_reset(rng=random_key)\n",
"while not state.done:\n",
" rollout.append(state)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/distributed_mapelites.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/jumanji_snake.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
"env = jumanji.make('Snake-v1')\n",
"\n",
"# Reset your (jit-able) environment\n",
"key = jax.random.PRNGKey(0)\n",
"key = jax.random.key(0)\n",
"state, timestep = jax.jit(env.reset)(key)\n",
"\n",
"# Interact with the (jit-able) environment\n",
Expand All @@ -137,7 +137,7 @@
"outputs": [],
"source": [
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# get number of actions\n",
"num_actions = env.action_spec().maximum + 1\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/mapelites.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.Key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down Expand Up @@ -494,7 +494,7 @@
"outputs": [],
"source": [
"rollout = []\n",
"rng = jax.random.PRNGKey(seed=1)\n",
"rng = jax.random.key(seed=1)\n",
"state = jit_env_reset(rng=rng)\n",
"while not state.done:\n",
" rollout.append(state)\n",
Expand Down
8 changes: 6 additions & 2 deletions examples/me_sac_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@
"outputs": [],
"source": [
"%%time\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"key, subkey = jax.random.split(key)\n",
"env_states = jax.jit(env.reset)(rng=subkey)\n",
"eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)"
Expand Down Expand Up @@ -311,6 +311,10 @@
" observation_size=env.observation_size,\n",
" buffer_size=buffer_size,\n",
")\n",
"\n",
"# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n",
"keys = jax.random.key_data(keys)\n",
"\n",
"keys, training_states, _ = jax.pmap(agent_init_fn, axis_name=\"p\", devices=devices)(keys)"
]
},
Expand Down Expand Up @@ -504,7 +508,7 @@
"%%time\n",
"rollout = []\n",
"\n",
"rng = jax.random.PRNGKey(seed=1)\n",
"rng = jax.random.key(seed=1)\n",
"env_state = jax.jit(env.reset)(rng=rng)\n",
"\n",
"training_state, env_state = jax.tree_map(\n",
Expand Down
6 changes: 5 additions & 1 deletion examples/me_td3_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@
"outputs": [],
"source": [
"%%time\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"key, subkey = jax.random.split(key)\n",
"env_states = jax.jit(env.reset)(rng=subkey)\n",
"eval_env_first_states = jax.jit(eval_env.reset)(rng=subkey)"
Expand Down Expand Up @@ -314,6 +314,10 @@
" observation_size=env.observation_size,\n",
" buffer_size=buffer_size,\n",
")\n",
"\n",
"# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n",
"keys = jax.random.key_data(keys)\n",
"\n",
"keys, training_states, _ = jax.pmap(agent_init_fn, axis_name=\"p\", devices=devices)(keys)"
]
},
Expand Down
2 changes: 1 addition & 1 deletion examples/mees.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down
4 changes: 2 additions & 2 deletions examples/mels.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down Expand Up @@ -509,7 +509,7 @@
"outputs": [],
"source": [
"rollout = []\n",
"rng = jax.random.PRNGKey(seed=1)\n",
"rng = jax.random.key(seed=1)\n",
"state = jit_env_reset(rng=rng)\n",
"while not state.done:\n",
" rollout.append(state)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/mome.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@
"outputs": [],
"source": [
"# initial population\n",
"random_key = jax.random.PRNGKey(42)\n",
"random_key = jax.random.key(42)\n",
"random_key, subkey = jax.random.split(random_key)\n",
"genotypes = jax.random.uniform(\n",
" random_key, (batch_size, num_variables), minval=minval, maxval=maxval, dtype=jnp.float32\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/nsga2_spea2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@
"outputs": [],
"source": [
"# Initial population\n",
"random_key = jax.random.PRNGKey(0)\n",
"random_key = jax.random.key(0)\n",
"random_key, subkey = jax.random.split(random_key)\n",
"genotypes = jax.random.uniform(\n",
" subkey, (batch_size, genotype_dim), minval=minval, maxval=maxval, dtype=jnp.float32\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/omgmega.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@
"metadata": {},
"outputs": [],
"source": [
"random_key = jax.random.PRNGKey(0)\n",
"random_key = jax.random.key(0)\n",
"\n",
"# defines the population\n",
"random_key, subkey = jax.random.split(random_key)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/pga_aurora.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/pgame.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/qdpg.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@
"env = environments.create(env_name, episode_length=episode_length)\n",
"\n",
"# Init a random key\n",
"random_key = jax.random.PRNGKey(seed)\n",
"random_key = jax.random.key(seed)\n",
"\n",
"# Init policy network\n",
"policy_layer_sizes = policy_hidden_layer_sizes + (env.action_size,)\n",
Expand Down
8 changes: 6 additions & 2 deletions examples/sac_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@
"outputs": [],
"source": [
"# %%time\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"key, *keys = jax.random.split(key, num=1 + num_devices)\n",
"keys = jnp.stack(keys)\n",
"env_states, eval_env_first_states = jax.pmap(\n",
Expand Down Expand Up @@ -269,6 +269,10 @@
" observation_size=env.observation_size,\n",
" buffer_size=buffer_size,\n",
")\n",
"\n",
"# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n",
"keys = jax.random.key_data(keys)\n",
"\n",
"keys, training_states, replay_buffers = jax.pmap(\n",
" agent_init_fn, axis_name=\"p\", devices=devices\n",
")(keys)"
Expand Down Expand Up @@ -518,7 +522,7 @@
"%%time\n",
"rollout = []\n",
"\n",
"rng = jax.random.PRNGKey(seed=1)\n",
"rng = jax.random.key(seed=1)\n",
"env_state = jax.jit(env.reset)(rng=rng)\n",
"\n",
"training_state, env_state = jax.tree_util.tree_map(\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/me_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def run_me() -> None:
max_bd = 1.0

# Init a random key
random_key = jax.random.PRNGKey(seed)
random_key = jax.random.key(seed)

# Init population of controllers
random_key, subkey = jax.random.split(random_key)
Expand Down
4 changes: 2 additions & 2 deletions examples/smerl.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@
" eval_metrics=True,\n",
")\n",
"\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.Key(seed)\n",
"env_state = jax.jit(env.reset)(rng=key)\n",
"eval_env_first_state = jax.jit(eval_env.reset)(rng=key)\n",
"\n",
Expand Down Expand Up @@ -504,7 +504,7 @@
"outputs": [],
"source": [
"rollout = []\n",
"random_key = jax.random.PRNGKey(seed=1)\n",
"random_key = jax.random.key(seed=1)\n",
"state = jit_env_reset(rng=random_key)\n",
"while not state.done:\n",
" rollout.append(state)\n",
Expand Down
6 changes: 5 additions & 1 deletion examples/td3_pbt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@
"outputs": [],
"source": [
"%%time\n",
"key = jax.random.PRNGKey(seed)\n",
"key = jax.random.key(seed)\n",
"key, *keys = jax.random.split(key, num=1 + num_devices)\n",
"keys = jnp.stack(keys)\n",
"env_states, eval_env_first_states = jax.pmap(\n",
Expand Down Expand Up @@ -232,6 +232,10 @@
" observation_size=env.observation_size,\n",
" buffer_size=buffer_size,\n",
")\n",
"\n",
"# Need to convert to PRNGKey because of github.com/jax-ml/jax/issues/23647\n",
"keys = jax.random.key_data(keys)\n",
"\n",
"keys, training_states, replay_buffers = jax.pmap(\n",
" agent_init_fn, axis_name=\"p\", devices=devices\n",
")(keys)"
Expand Down
2 changes: 1 addition & 1 deletion qdax/core/containers/mapelites_repertoire.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def compute_cvt_centroids(
init="k-means++",
n_clusters=num_centroids,
n_init=1,
random_state=RandomState(subkey),
random_state=RandomState(jax.random.key_data(subkey)),
)
k_means.fit(x)
centroids = k_means.cluster_centers_
Expand Down
Loading

0 comments on commit 6656f5e

Please sign in to comment.