Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mypy and static type checking to Habitat Lab #492

Merged
merged 27 commits into from
Oct 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
4f9a1b8
typed Habitat Core
Skylion007 Sep 30, 2020
0594c7d
Fix more habitat typing issues
Skylion007 Oct 3, 2020
4fdb35e
Fix objnav typecheck
Skylion007 Oct 4, 2020
ab19922
Fix incorrect config type
Skylion007 Oct 4, 2020
3be5765
complete habitat lab mypy conversion
Skylion007 Oct 4, 2020
d5ea553
Migrate mypy config to mypy.ini
Skylion007 Oct 4, 2020
78c7ca8
Remove unused mypy config values
Skylion007 Oct 4, 2020
f3dde70
Re-enable episodes setter
Skylion007 Oct 4, 2020
d1ff0dd
Finish typing habitat_baselines
Skylion007 Oct 5, 2020
02aae0c
Stricten type hints in habitat.core.simulator
Skylion007 Oct 7, 2020
cbdc158
Fix registry type add constructor to Simulator
Skylion007 Oct 7, 2020
5e6a350
Fix import error
Skylion007 Oct 7, 2020
b38f02c
Make max singleton
Skylion007 Oct 7, 2020
ea2830a
Remove list copy
Skylion007 Oct 7, 2020
90ec223
Merge branch 'master' of https://github.com/facebookresearch/habitat-…
Skylion007 Oct 8, 2020
17a6e69
Update mypy with global ignores
Skylion007 Oct 10, 2020
99ee6cc
Do better list casting
Skylion007 Oct 10, 2020
b9070f9
Update pre-commit hooks
Skylion007 Oct 10, 2020
796052b
unify pre-commit style
Skylion007 Oct 10, 2020
d16c406
Add check-yaml pre-commit hook
Skylion007 Oct 10, 2020
843d1e2
Bugfix and optimize ci tests
Skylion007 Oct 10, 2020
8c2001a
Remove one more type ignore
Skylion007 Oct 10, 2020
16c5fbe
Update habitat_baselines/rl/ppo/ppo.py
Skylion007 Oct 10, 2020
d8f925c
Fix ppo imports
Skylion007 Oct 10, 2020
acf9fd8
remove dead code
Skylion007 Oct 12, 2020
304d378
Revert "remove dead code"
Skylion007 Oct 12, 2020
8580b03
Merge branch 'master' of https://github.com/facebookresearch/habitat-…
Skylion007 Oct 13, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,13 @@ jobs:
name: Build, install habitat-sim and run benchmark
no_output_timeout: 20m
command: |
while [ ! -f ./cuda_installed ]; do sleep 2; done # wait for CUDA
export PATH=$HOME/miniconda/bin:/usr/local/cuda/bin:$PATH
. activate habitat;
if [ ! -d ./habitat-sim ]
then
git clone https://github.com/facebookresearch/habitat-sim.git
git clone https://github.com/facebookresearch/habitat-sim.git --recursive
fi
while [ ! -f ./cuda_installed ]; do sleep 2; done # wait for CUDA
export PATH=$HOME/miniconda/bin:/usr/local/cuda/bin:$PATH
. activate habitat;
cd habitat-sim
pip install -r requirements.txt --progress-bar off
python setup.py install --headless --with-cuda --bullet
Expand Down
20 changes: 14 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ default_language_version:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.1.0
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: check-added-large-files
Expand All @@ -17,11 +17,13 @@ repos:
- id: check-executables-have-shebangs
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
exclude: "habitat_baselines/slambased/data/"
- id: mixed-line-ending
args: ['--fix=lf']

- repo: https://github.com/timothycrosley/isort
rev: 5.4.2
rev: 5.6.2
hooks:
- id: isort
exclude: docs/
Expand All @@ -34,23 +36,29 @@ repos:
exclude: ^examples/tutorials/(nb_python|colabs)

- repo: https://github.com/myint/autoflake
rev: master
rev: v1.4
hooks:
- id: autoflake
args: ['--expand-star-imports', '--ignore-init-module-imports', '--in-place', '-c']
args: ['--expand-star-imports', '--ignore-init-module-imports', '--in-place']
exclude: docs/

- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.3
rev: 3.8.4
hooks:
- id: flake8
exclude: docs/
additional_dependencies:
- flake8-bugbear
- flake8-comprehensions

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.790
hooks:
- id: mypy
pass_filenames: false

- repo: https://github.com/kynan/nbstripout
rev: master
rev: 0.3.9
hooks:
- id: nbstripout
files: ".ipynb"
Expand Down
1 change: 0 additions & 1 deletion habitat/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import yacs.config

# from habitat.config import Config as CN # type: ignore

# Default Habitat config node
class Config(yacs.config.CfgNode):
Expand Down
7 changes: 4 additions & 3 deletions habitat/core/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
and ``reset()`` methods.
"""

from typing import Any, Dict, Union
from typing import TYPE_CHECKING, Any, Dict, Union

from habitat.core.simulator import Observations
if TYPE_CHECKING:
from habitat.core.simulator import Observations


class Agent:
Expand All @@ -24,7 +25,7 @@ def reset(self) -> None:
raise NotImplementedError

def act(
self, observations: Observations
self, observations: "Observations"
) -> Union[int, str, Dict[str, Any]]:
r"""Called to produce an action to perform in an environment.

Expand Down
16 changes: 10 additions & 6 deletions habitat/core/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,20 @@

import os
from collections import defaultdict
from typing import Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional

from habitat.config.default import get_config
from habitat.core.agent import Agent
from habitat.core.env import Env

if TYPE_CHECKING:
from habitat.core.agent import Agent


class Benchmark:
r"""Benchmark for evaluating agents in environments."""

def __init__(
self, config_paths: Optional[str] = None, eval_remote=False
self, config_paths: Optional[str] = None, eval_remote: bool = False
) -> None:
r"""..

Expand All @@ -38,7 +40,7 @@ def __init__(
self._env = Env(config=config_env)

def remote_evaluate(
self, agent: Agent, num_episodes: Optional[int] = None
self, agent: "Agent", num_episodes: Optional[int] = None
):
# The modules imported below are specific to habitat-challenge remote evaluation.
# These modules are not part of the habitat-lab repository.
Expand Down Expand Up @@ -113,7 +115,9 @@ def remote_ep_over(stub):

return avg_metrics

def local_evaluate(self, agent: Agent, num_episodes: Optional[int] = None):
def local_evaluate(
self, agent: "Agent", num_episodes: Optional[int] = None
) -> Dict[str, float]:
if num_episodes is None:
num_episodes = len(self._env.episodes)
else:
Expand Down Expand Up @@ -147,7 +151,7 @@ def local_evaluate(self, agent: Agent, num_episodes: Optional[int] = None):
return avg_metrics

def evaluate(
self, agent: Agent, num_episodes: Optional[int] = None
self, agent: "Agent", num_episodes: Optional[int] = None
) -> Dict[str, float]:
r"""..

Expand Down
40 changes: 24 additions & 16 deletions habitat/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
Iterator,
List,
Optional,
Sequence,
TypeVar,
Union,
)

import attr
import numpy as np
from numpy import ndarray

from habitat.config import Config
from habitat.core.utils import not_none_validator
Expand Down Expand Up @@ -58,7 +61,7 @@ class Episode:
start_rotation: List[float] = attr.ib(
default=None, validator=not_none_validator
)
info: Optional[Dict[str, str]] = None
info: Optional[Dict[str, Any]] = None
_shortest_path_cache: Any = attr.ib(init=False, default=None)

def __getstate__(self):
Expand Down Expand Up @@ -101,8 +104,8 @@ def get_scenes_to_load(cls, config: Config) -> List[str]:

:return: A list of scene names that would be loaded with the dataset
"""
assert cls.check_config_paths_exist(config)
dataset = cls(config)
assert cls.check_config_paths_exist(config) # type: ignore[attr-defined]
dataset = cls(config) # type: ignore[call-arg]
return list(map(cls.scene_from_scene_path, dataset.scene_ids))

@classmethod
Expand Down Expand Up @@ -279,7 +282,7 @@ def get_splits(
self.num_episodes, num_episodes, replace=False
)
if collate_scene_ids:
scene_ids = {}
scene_ids: Dict[str, List[int]] = {}
for rand_ind in rand_items:
scene = self.episodes[rand_ind].scene_id
if scene not in scene_ids:
Expand Down Expand Up @@ -334,7 +337,7 @@ class EpisodeIterator(Iterator):

def __init__(
self,
episodes: List[T],
episodes: Sequence[T],
cycle: bool = True,
shuffle: bool = False,
group_by_scene: bool = True,
Expand All @@ -343,7 +346,7 @@ def __init__(
num_episode_sample: int = -1,
step_repetition_range: float = 0.2,
seed: int = None,
):
) -> None:
r"""..

:param episodes: list of episodes.
Expand Down Expand Up @@ -375,6 +378,9 @@ def __init__(
episodes, num_episode_sample, replace=False
)

if not isinstance(episodes, list):
episodes = list(episodes)

self.episodes = episodes
self.cycle = cycle
self.group_by_scene = group_by_scene
Expand All @@ -391,17 +397,17 @@ def __init__(

self._rep_count = -1 # 0 corresponds to first episode already returned
self._step_count = 0
self._prev_scene_id = None
self._prev_scene_id: Optional[str] = None

self._iterator = iter(self.episodes)

self.step_repetition_range = step_repetition_range
self._set_shuffle_intervals()

def __iter__(self):
def __iter__(self) -> "EpisodeIterator":
return self

def __next__(self):
def __next__(self) -> Episode:
r"""The main logic for handling how episodes will be iterated.

:return: next episode.
Expand Down Expand Up @@ -459,7 +465,9 @@ def _shuffle(self) -> None:

self._iterator = iter(episodes)

def _group_scenes(self, episodes):
def _group_scenes(
self, episodes: Union[Sequence[Episode], List[Episode], ndarray]
) -> List[T]:
r"""Internal method that groups episodes by scene
Groups will be ordered by the order the first episode of a given
scene is in the list of episodes
Expand All @@ -469,23 +477,23 @@ def _group_scenes(self, episodes):
"""
assert self.group_by_scene

scene_sort_keys = {}
scene_sort_keys: Dict[str, int] = {}
for e in episodes:
if e.scene_id not in scene_sort_keys:
scene_sort_keys[e.scene_id] = len(scene_sort_keys)

return sorted(episodes, key=lambda e: scene_sort_keys[e.scene_id])
return sorted(episodes, key=lambda e: scene_sort_keys[e.scene_id]) # type: ignore[arg-type]

def step_taken(self):
def step_taken(self) -> None:
self._step_count += 1

@staticmethod
def _randomize_value(value, value_range):
def _randomize_value(value: int, value_range: float) -> int:
return random.randint(
int(value * (1 - value_range)), int(value * (1 + value_range))
)

def _set_shuffle_intervals(self):
def _set_shuffle_intervals(self) -> None:
if self.max_scene_repetition_episodes > 0:
self._max_rep_episode = self.max_scene_repetition_episodes
else:
Expand All @@ -498,7 +506,7 @@ def _set_shuffle_intervals(self):
else:
self._max_rep_step = None

def _forced_scene_switch_if(self):
def _forced_scene_switch_if(self) -> None:
do_switch = False
self._rep_count += 1

Expand Down
2 changes: 1 addition & 1 deletion habitat/core/embodied_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def reset(self, episode: Type[Episode]):

return observations

def step(self, action: Union[int, Dict[str, Any]], episode: Type[Episode]):
def step(self, action: Dict[str, Any], episode: Type[Episode]):
if "action_args" not in action or action["action_args"] is None:
action["action_args"] = {}
action_name = action["action"]
Expand Down
26 changes: 20 additions & 6 deletions habitat/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,17 @@

import random
import time
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union
from typing import (
Any,
Dict,
Iterator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)

import gym
import numba
Expand Down Expand Up @@ -77,7 +87,11 @@ def __init__(
self._dataset = make_dataset(
id_dataset=config.DATASET.TYPE, config=config.DATASET
)
self._episodes = self._dataset.episodes if self._dataset else []
self._episodes = (
self._dataset.episodes
if self._dataset
else cast(List[Type[Episode]], [])
)
self._current_episode = None
iter_option_dict = {
k.lower(): v
Expand Down Expand Up @@ -342,14 +356,14 @@ def habitat_env(self) -> Env:
def episodes(self) -> List[Type[Episode]]:
return self._env.episodes

@property
def current_episode(self) -> Type[Episode]:
return self._env.current_episode

@episodes.setter
def episodes(self, episodes: List[Type[Episode]]) -> None:
self._env.episodes = episodes

@property
def current_episode(self) -> Type[Episode]:
return self._env.current_episode

def reset(self) -> Observations:
return self._env.reset()

Expand Down
4 changes: 2 additions & 2 deletions habitat/core/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def __init__(
):
super().__init__(name, level)
if filename is not None:
handler = logging.FileHandler(filename, filemode)
handler = logging.FileHandler(filename, filemode) # type:ignore
else:
handler = logging.StreamHandler(stream)
handler = logging.StreamHandler(stream) # type:ignore
self._formatter = logging.Formatter(format, dateformat, style)
handler.setFormatter(self._formatter)
super().addHandler(handler)
Expand Down
Loading