diff --git a/.gitignore b/.gitignore index f3365c6d3..e3e32e85c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,7 @@ *.py~ *.bak .pytest_cache +.mypy_cache .DS_Store .idea .vscode diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index dc09ed242..ad9bfaa05 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -2,7 +2,7 @@ image: stablebaselines/stable-baselines3-cpu:1.4.1a0 type-check: script: - - pip install pytype --upgrade + - pip install pytype mypy --upgrade - make type pytest: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 67d794d80..89b15b72e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -78,7 +78,7 @@ To run tests with `pytest`: make pytest ``` -Type checking with `pytype`: +Type checking with `pytype` and `mypy`: ``` make type @@ -91,7 +91,7 @@ make check-codestyle make lint ``` -To run `pytype`, `format` and `lint` in one command: +To run `type`, `format` and `lint` in one command: ``` make commit-checks ``` diff --git a/Makefile b/Makefile index 9954c7d7b..3e4e73456 100644 --- a/Makefile +++ b/Makefile @@ -4,9 +4,14 @@ LINT_PATHS=stable_baselines3/ tests/ docs/conf.py setup.py pytest: ./scripts/run_tests.sh -type: +pytype: pytype -j auto +mypy: + MYPY_FORCE_COLOR=1 mypy ${LINT_PATHS} + +type: pytype mypy + lint: # stop the build if there are Python syntax errors or undefined names # see https://lintlyci.github.io/Flake8Rules/ diff --git a/README.md b/README.md index 78bcc672f..26b796cd0 100644 --- a/README.md +++ b/README.md @@ -198,9 +198,9 @@ pip install pytest pytest-cov make pytest ``` -You can also do a static type check using `pytype`: +You can also do a static type check using `pytype` and `mypy`: ``` -pip install pytype +pip install pytype mypy make type ``` diff --git a/docs/conf.py b/docs/conf.py index b44be6f66..672db5554 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -13,6 +13,7 @@ # import os import sys +from typing import Dict, List from unittest.mock import MagicMock # We CANNOT enable 'sphinxcontrib.spelling' because ReadTheDocs.org does not support @@ -37,7 +38,7 @@ class Mock(MagicMock): - __subclasses__ = [] + __subclasses__ = [] # type: ignore @classmethod def __getattr__(cls, name): @@ -48,7 +49,7 @@ def __getattr__(cls, name): # Note: because of that we cannot test examples using CI # 'torch', 'torch.nn', 'torch.nn.functional', # DO not mock modules for now, we will need to do that for read the docs later -MOCK_MODULES = [] +MOCK_MODULES: List[str] = [] sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) # Read version from file @@ -171,7 +172,7 @@ def setup(app): # -- Options for LaTeX output ------------------------------------------------ -latex_elements = { +latex_elements: Dict[str, str] = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 5d2227f66..3ef78d5ce 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -16,6 +16,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Introduced mypy type checking SB3-Contrib ^^^^^^^^^^^ diff --git a/setup.cfg b/setup.cfg index eb3b04c71..6cf08b3e6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,6 +25,60 @@ markers = inputs = stable_baselines3 disable = pyi-error +[mypy] +ignore_missing_imports = True +follow_imports = silent +show_error_codes = True +exclude = (?x)( + stable_baselines3/a2c/a2c.py$ + | stable_baselines3/common/atari_wrappers.py$ + | stable_baselines3/common/base_class.py$ + | stable_baselines3/common/buffers.py$ + | stable_baselines3/common/callbacks.py$ + | stable_baselines3/common/distributions.py$ + | stable_baselines3/common/env_util.py$ + | stable_baselines3/common/envs/bit_flipping_env.py$ + | stable_baselines3/common/envs/identity_env.py$ + | stable_baselines3/common/envs/multi_input_envs.py$ + | stable_baselines3/common/logger.py$ + | stable_baselines3/common/monitor.py$ + | stable_baselines3/common/off_policy_algorithm.py$ + | stable_baselines3/common/on_policy_algorithm.py$ + | stable_baselines3/common/policies.py$ + | stable_baselines3/common/preprocessing.py$ + | stable_baselines3/common/save_util.py$ + | stable_baselines3/common/sb2_compat/rmsprop_tf_like.py$ + | stable_baselines3/common/torch_layers.py$ + | stable_baselines3/common/type_aliases.py$ + | stable_baselines3/common/utils.py$ + | stable_baselines3/common/vec_env/__init__.py$ + | stable_baselines3/common/vec_env/base_vec_env.py$ + | stable_baselines3/common/vec_env/dummy_vec_env.py$ + | stable_baselines3/common/vec_env/stacked_observations.py$ + | stable_baselines3/common/vec_env/subproc_vec_env.py$ + | stable_baselines3/common/vec_env/util.py$ + | stable_baselines3/common/vec_env/vec_check_nan.py$ + | stable_baselines3/common/vec_env/vec_extract_dict_obs.py$ + | stable_baselines3/common/vec_env/vec_frame_stack.py$ + | stable_baselines3/common/vec_env/vec_monitor.py$ + | stable_baselines3/common/vec_env/vec_normalize.py$ + | stable_baselines3/common/vec_env/vec_transpose.py$ + | stable_baselines3/common/vec_env/vec_video_recorder.py$ + | stable_baselines3/dqn/dqn.py$ + | stable_baselines3/dqn/policies.py$ + | stable_baselines3/her/her_replay_buffer.py$ + | stable_baselines3/ppo/ppo.py$ + | stable_baselines3/sac/policies.py$ + | stable_baselines3/sac/sac.py$ + | stable_baselines3/td3/policies.py$ + | stable_baselines3/td3/td3.py$ + | tests/test_distributions.py$ + | tests/test_logger.py$ + | tests/test_tensorboard.py$ + | tests/test_train_eval_mode.py$ + | tests/test_vec_normalize.py$ + ) + [flake8] ignore = W503,W504,E203,E231 # line breaks before and after binary operators # Ignore import not used when aliases are defined diff --git a/setup.py b/setup.py index bfcb56ca1..44bcb158d 100644 --- a/setup.py +++ b/setup.py @@ -95,6 +95,7 @@ "pytest-xdist", # Type check "pytype", + "mypy", # Lint code "flake8>=3.8", # Find likely bugs diff --git a/tests/test_utils.py b/tests/test_utils.py index 34db00e0a..7b228f0aa 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -149,6 +149,7 @@ def test_evaluate_policy(direct_policy: bool): def dummy_callback(locals_, _globals): locals_["model"].n_callback_calls += 1 + assert model.policy is not None policy = model.policy if direct_policy else model policy.n_callback_calls = 0 _, episode_lengths = evaluate_policy(