Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
tomtseng committed Jan 7, 2025
1 parent f64b3fc commit 5820cfe
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/imitation/algorithms/adversarial/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __init__(
self.debug_use_ground_truth = debug_use_ground_truth
self.venv = venv
self.gen_algo = gen_algo
self._reward_net = reward_net.to(gen_algo.device)
self._reward_net: reward_nets.RewardNet = reward_net.to(gen_algo.device)
self._log_dir = util.parse_path(log_dir)

# Create graph for optimising/recording stats on discriminator
Expand Down
2 changes: 1 addition & 1 deletion src/imitation/policies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def _predict(
):
np_actions = []
if isinstance(obs, dict):
np_obs = types.DictObs(
np_obs: Union[types.DictObs, np.ndarray] = types.DictObs(
{k: v.detach().cpu().numpy() for k, v in obs.items()},
)
else:
Expand Down

0 comments on commit 5820cfe

Please sign in to comment.