Skip to content

Commit

Permalink
Remove unecessary jit from 03_eval_finetuned (fixes #22)
Browse files Browse the repository at this point in the history
  • Loading branch information
kvablack authored Dec 27, 2023
1 parent bcaddd0 commit d6267cd
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions examples/03_eval_finetuned.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ def main(_):
env, model.dataset_statistics, normalization_type="normal"
)

# jit model action prediction function for faster inference
policy_fn = jax.jit(model.sample_actions)

# running rollouts
for _ in range(3):
obs, info = env.reset()
Expand All @@ -77,7 +74,7 @@ def main(_):
episode_return = 0.0
while len(images) < 400:
# model returns actions of shape [batch, pred_horizon, action_dim] -- remove batch
actions = policy_fn(
actions = model.sample_actions(
jax.tree_map(lambda x: x[None], obs), task, rng=jax.random.PRNGKey(0)
)
actions = actions[0]
Expand Down

0 comments on commit d6267cd

Please sign in to comment.