From 4ca691827a803a5fa6ab189cc5b35f3f95f767c3 Mon Sep 17 00:00:00 2001 From: rklassert Date: Mon, 15 Apr 2024 14:35:28 +0200 Subject: [PATCH] Fix Jumanji environment spec access in examples --- examples/anakin_dqn_example.ipynb | 2 +- examples/anakin_ppo_example.ipynb | 2 +- examples/anakin_prioritised_dqn_example.ipynb | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/anakin_dqn_example.ipynb b/examples/anakin_dqn_example.ipynb index 455830b..9c860ca 100644 --- a/examples/anakin_dqn_example.ipynb +++ b/examples/anakin_dqn_example.ipynb @@ -310,7 +310,7 @@ "):\n", " \"\"\"Sets up the experiment.\"\"\"\n", " cores_count = len(jax.devices()) # get available TPU cores.\n", - " network = get_network_fn(env.action_spec().num_values) # define network.\n", + " network = get_network_fn(env.action_spec.num_values) # define network.\n", " optim = optax.adam(step_size) # define optimiser.\n", "\n", " rng, rng_e, rng_p = random.split(random.PRNGKey(seed), num=3) # prng keys.\n", diff --git a/examples/anakin_ppo_example.ipynb b/examples/anakin_ppo_example.ipynb index e989fe5..d6d632d 100644 --- a/examples/anakin_ppo_example.ipynb +++ b/examples/anakin_ppo_example.ipynb @@ -253,7 +253,7 @@ " \"\"\"Sets up the experiment and returns the necessary information.\"\"\"\n", "\n", " cores_count = len(jax.devices()) # get available TPU cores.\n", - " network = get_network_fn(env.action_spec().num_values) # define network.\n", + " network = get_network_fn(env.action_spec.num_values) # define network.\n", " optim = optax.adam(step_size) # define optimiser.\n", "\n", " rng, rng_e, rng_p = random.split(random.PRNGKey(seed), num=3) # prng keys.\n", diff --git a/examples/anakin_prioritised_dqn_example.ipynb b/examples/anakin_prioritised_dqn_example.ipynb index 542ba94..377a54f 100644 --- a/examples/anakin_prioritised_dqn_example.ipynb +++ b/examples/anakin_prioritised_dqn_example.ipynb @@ -333,7 +333,7 @@ "):\n", " \"\"\"Sets up the experiment.\"\"\"\n", " cores_count = len(jax.devices()) # get available TPU cores.\n", - " network = get_network_fn(env.action_spec().num_values) # define network.\n", + " network = get_network_fn(env.action_spec.num_values) # define network.\n", " optim = optax.adam(step_size) # define optimiser.\n", "\n", " rng, rng_e, rng_p = random.split(random.PRNGKey(seed), num=3) # prng keys.\n",