
Fix mypy issues in tests and examples #1077

Merged

Changes from all commits (123 commits)
b17c6dc
Remove mutable logger_factory used by examples
dantp-ai Feb 19, 2024
933b27a
Fix mypy issues
dantp-ai Feb 21, 2024
7947f53
Add type annotations to funcs signatures
dantp-ai Feb 23, 2024
1163494
Rename variable to resolve type conflict
dantp-ai Feb 23, 2024
aa0f131
Provide right input type to buffer methods
dantp-ai Feb 23, 2024
7c91575
Index sample from buffer according to cheatsheet recommendations
dantp-ai Feb 23, 2024
76fe01e
Type index to make mypy happy
dantp-ai Feb 23, 2024
33df04c
Make mypy happy and check for mask attribute before asserting tests
dantp-ai Feb 23, 2024
ba5d74e
Make mypy happy and specify union of types
dantp-ai Feb 23, 2024
8d03ea4
Make mypy happy and type ndarray
dantp-ai Feb 23, 2024
13ae7a8
Use recommended outer buffer indexing
dantp-ai Feb 23, 2024
f282d81
Access buffer attrs with __getattr__
dantp-ai Feb 23, 2024
2805967
Ignore mypy issue as it explicitly tests for invalid type
dantp-ai Feb 29, 2024
eaa4503
Make mypy happy & use explicit var typing and ignore
dantp-ai Feb 29, 2024
cb765b4
Ignore type on explicit error
dantp-ai Feb 29, 2024
1aacd10
Remove redundant assert
dantp-ai Feb 29, 2024
a6a210e
Merge branch 'master' into refactoring/mypy-issues-test-and-examples
dantp-ai Mar 3, 2024
ef9581b
Add typing to func args
dantp-ai Mar 3, 2024
0356d8d
Add type annotations to funcs
dantp-ai Mar 4, 2024
c48d50f
Fix mypy issues
dantp-ai Mar 5, 2024
ebacf99
Use DataclassPPrintMixin to print collector stats
dantp-ai Mar 6, 2024
b6accd4
Add type annotations to func
dantp-ai Mar 6, 2024
41a9d03
Merge branch 'master' into refactoring/mypy-issues-test-and-examples
dantp-ai Mar 6, 2024
52bb1e3
Add type annotations to funcs
dantp-ai Mar 6, 2024
14b769f
Fix policy annotation
dantp-ai Mar 6, 2024
59f9bb1
Add type annotations func
dantp-ai Mar 6, 2024
ba6df44
Type policy var to resolve mypy confusion
dantp-ai Mar 6, 2024
ebd140f
Fix mypy issues
dantp-ai Mar 6, 2024
bce45a7
Check env.spec for none before accessing attrs
dantp-ai Mar 6, 2024
1dcd319
Add type annotation using space info
dantp-ai Mar 6, 2024
48af507
Add type annotations to NoopResetEnv's methods and helper funcs
dantp-ai Mar 6, 2024
2ec7c67
Use only integers with Generator
dantp-ai Mar 6, 2024
e1a85fa
Use DataclassPPrintMixin to print collect stats
dantp-ai Mar 7, 2024
8d9e168
Respect mypy typing for vars/args
dantp-ai Mar 7, 2024
6f31ac1
Fix many mypy issues
dantp-ai Mar 7, 2024
21e3805
Rename var to resolve ambiguity for mypy
dantp-ai Mar 8, 2024
010395d
Fix mypy issues (see below)
dantp-ai Mar 8, 2024
ed70e82
Add type annotations to DummyDataset and FiniteEnv
dantp-ai Mar 8, 2024
678c1c3
Fix mypy issues
dantp-ai Mar 14, 2024
70e6dc1
Fix some mypy issues
dantp-ai Mar 14, 2024
2db9b20
Fix many mypy issues related to:
dantp-ai Mar 15, 2024
b4d9450
Fix mypy issues:
dantp-ai Mar 15, 2024
62b884d
ignore mypy check
dantp-ai Mar 15, 2024
2b4ffa7
Fix mypy issues:
dantp-ai Mar 16, 2024
2fb1d95
Add missing type annotations
dantp-ai Mar 16, 2024
0398ef7
Fix some mypy issues:
dantp-ai Mar 16, 2024
f47b4ad
Merge branch 'master' into refactoring/mypy-issues-test-and-examples
dantp-ai Mar 16, 2024
cc04d08
Add missing type annotations to funcs
dantp-ai Mar 16, 2024
d265b2f
Fix mypy issues
dantp-ai Mar 16, 2024
48af7a5
Pass correct typed param env_fns to SubprocVectorEnv
dantp-ai Mar 16, 2024
4330e33
Fix more mypy issues:
dantp-ai Mar 16, 2024
91a3ee2
Bugfix: Tuple item assignment
dantp-ai Mar 17, 2024
a481e2f
Make mypy happy and use [] instead of . notation
dantp-ai Mar 18, 2024
5a68a9a
Merge branch 'master' into refactoring/mypy-issues-test-and-examples
dantp-ai Mar 20, 2024
c6aa77b
Typing: extend index type
Mar 21, 2024
0de06a4
Typing in tests: added asserts and cast to remove some mypy errors
Mar 21, 2024
0596f04
Revert output type to tuple[InfoStats, BasePolicy]
dantp-ai Mar 21, 2024
3031c8f
For mypy: Store obs/obs_next in new var and assert type
dantp-ai Mar 22, 2024
5451119
For mypy: store intermediate vars & assert type
dantp-ai Mar 22, 2024
a797811
Use pprint_asdict() instead
dantp-ai Mar 23, 2024
b09b581
Check for none
dantp-ai Mar 23, 2024
989633e
Merge branch 'master' into refactoring/mypy-issues-test-and-examples
dantp-ai Mar 23, 2024
41748f1
Use pprint_asdict to print CollectStats
dantp-ai Mar 23, 2024
12d5e68
Ignore type as mypy doesn't know that it should be wrong here
dantp-ai Mar 23, 2024
2fdab27
Ignore type annotation for step() because env can generate non-scalar…
dantp-ai Mar 23, 2024
bda6b10
Make mypy understand what type weight is
dantp-ai Mar 23, 2024
c4f6c2d
Make mypy happy and use proper type for adding constant to batch
dantp-ai Mar 24, 2024
2d6f5f1
Use setattr/getattr because mypy doesn't know
dantp-ai Mar 24, 2024
f5084ca
Remove unused var
dantp-ai Mar 24, 2024
2e8f3c3
Use SpaceInfo to determine types of action/obs space
dantp-ai Mar 24, 2024
176ecc0
Revert "Use SpaceInfo to determne types action/obs space"
dantp-ai Mar 24, 2024
444b464
Cover case when state is dict
dantp-ai Mar 24, 2024
01aad1a
Set integer default value for batch_size
dantp-ai Mar 24, 2024
06ae70c
Add typing to env methods and use Gym API >v0.26 (with terminated, tr…
dantp-ai Mar 24, 2024
aba1d8c
Treat case when cpu_count() is None
dantp-ai Mar 24, 2024
92e19be
Fix mypy issues:
dantp-ai Mar 24, 2024
5ae29df
Ignore mypy typing on lines that use old Gym API:
dantp-ai Mar 25, 2024
65c4aa1
Make mypy happy and add type to var
dantp-ai Mar 25, 2024
34d94bc
Use space_info to type env spaces
dantp-ai Mar 25, 2024
216ba7c
Use assert instead of cast to check for obs_space/action_space:
dantp-ai Mar 25, 2024
524f0af
Assert action space
dantp-ai Mar 25, 2024
fe08d6e
Use assert as mypy doesn't know that FetchReach env has compute_rewar…
dantp-ai Mar 26, 2024
7a5071c
Check for none before comparing mean_rewards to reward_threshold
dantp-ai Mar 26, 2024
c8e5448
Assert action space before accessing attributes specific to that space
dantp-ai Mar 26, 2024
f71ad85
Refactor way DQN API is used:
dantp-ai Mar 26, 2024
56608da
Fix FiniteVectorEnv.reset() to satisfy superclass type annotations
dantp-ai Mar 26, 2024
10352d4
Fix mypy issues for AtariWrappers
dantp-ai Mar 26, 2024
c21586a
Make mypy happy and use typing for obs_space_dtype
dantp-ai Mar 27, 2024
75bb1f0
Assert env.action_space = MultiDiscrete:
dantp-ai Mar 27, 2024
a0d1427
Extend IndexType to explicitly have list[int]:
dantp-ai Mar 27, 2024
66d92af
Use SpaceInfo to type env obs/action space:
dantp-ai Mar 27, 2024
d0c8745
Use SpaceInfo to type env obs/action space:
dantp-ai Mar 27, 2024
d4b5d23
Assert action space before accessing space-specific attrs
dantp-ai Mar 27, 2024
2255f6e
Add type hints to obs/action spaces:
dantp-ai Mar 27, 2024
8c5fa77
Add type hints to action/obs spaces:
dantp-ai Mar 27, 2024
069a2e6
Minor reformat comments
dantp-ai Mar 28, 2024
d01420c
Merge branch 'master' into refactoring/mypy-issues-test-and-examples
dantp-ai Mar 28, 2024
feb0ee1
Merge branch 'thuml-master' into refactoring/mypy-issues-test-and-exa…
Apr 1, 2024
dab400f
Post-resolving conflicts in tests and examples
Apr 1, 2024
aaa56ea
Pyproject: added stubs, included tests and examples in type-check
Apr 1, 2024
32951b2
Extended pre-commit type check [skip ci]
Apr 1, 2024
6235a37
Typo in test
Apr 1, 2024
4a867bd
Remove type in IndexType for batch
dantp-ai Apr 2, 2024
528eb10
Use assert hasattr instead of getattr
dantp-ai Apr 2, 2024
d10b8b2
Remove iter since Batch already implements __iter__
dantp-ai Apr 2, 2024
5ed5d50
Use Literal instead of asserting members of list
dantp-ai Apr 2, 2024
7a18c4d
Use stop_fn for running this example
dantp-ai Apr 2, 2024
4fb294d
Use ValueError to inform user about what type of env is supported.
dantp-ai Apr 2, 2024
2ace3ef
Use more specific type hint for policy to get access to policy-specif…
dantp-ai Apr 2, 2024
6e50389
Refactor type annotation make_vizdoom_env
dantp-ai Apr 2, 2024
1ac275b
Use kw-args for better readability
dantp-ai Apr 2, 2024
322b6aa
Use os.path.join
dantp-ai Apr 2, 2024
0e2babb
Remove if and assert hasattr beforehand
dantp-ai Apr 2, 2024
71a4006
Return non-empty dict when reset
dantp-ai Apr 2, 2024
10669a6
Refactor ActorFactoryAtariDQN hidden_size semantics and output_dim of…
dantp-ai Apr 2, 2024
5567ce0
Refactor type annotations of scale_obs:
dantp-ai Apr 2, 2024
1c79d19
Simplify checks of obs_shape for atari envs
dantp-ai Apr 2, 2024
2b6722f
Use kw for input arguments to QRDQN
dantp-ai Apr 2, 2024
e8ba5ad
Made NetBase generic (explanation below), removed **kwargs from forward
Apr 3, 2024
e4d7d2f
SamplingConfig: support for batch_size=None
Apr 3, 2024
c1a4b40
Changelog [skip ci]
Apr 3, 2024
38b6b11
Merge branch 'thuml-master' into refactoring/mypy-issues-test-and-exa…
Apr 3, 2024
4c34a45
Changelog [skip ci]
Apr 3, 2024
Files changed
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -41,7 +41,7 @@ repos:
pass_filenames: false
- id: mypy
name: mypy
entry: poetry run mypy tianshou
entry: poetry run mypy tianshou examples test
# filenames should not be passed as they would collide with the config in pyproject.toml
pass_filenames: false
files: '^tianshou(/[^/]*)*/[^/]*\.py$'
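
With this change, the pre-commit mypy hook type-checks the examples and test directories in addition to the tianshou package itself. Presumably the same check can be reproduced locally with `pre-commit run mypy --all-files`, or by invoking the hook's entry directly, `poetry run mypy tianshou examples test`; the entry comes from the hunk above, while running it through pre-commit is an assumption about the contributor setup.
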
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -7,6 +7,7 @@
- `Collector`s can now be closed, and their reset is more granular. #1063
- Trainers can control whether collectors should be reset prior to training. #1063
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
- `SamplingConfig` supports `batch_size=None`. #1077

### Internal Improvements
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
@@ -20,6 +21,8 @@ instead of just `nn.Module`. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032
- Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032
- Exception no longer raised on `len` of empty `Batch`. #1084
- Tests and examples are now covered by `mypy`. #1077
- `NetBase` is used more consistently, with stricter typing achieved by making it generic. #1077

### Breaking Changes

@@ -30,10 +33,10 @@ explicitly or pass `reset_before_collect=True`. #1063
- Fixed `iter(Batch(...))` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
continuous and discrete cases. #1032
- `utils.net.common.Recurrent` now receives and returns a `RecurrentStateBatch` instead of a dict. #1077

### Tests
- Fixed env seeding in test_sac_with_il.py so that the test doesn't fail randomly. #1081


Started after v1.0.0
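
The `SamplingConfig` entry above is the most user-visible API change introduced by this PR. A minimal sketch of what it enables, assuming `SamplingConfig` is the dataclass from `tianshou.highlevel.config` and that the field names below match it; reading `batch_size=None` as "do not split collected data into mini-batches" is an assumption based on the changelog wording, not verified against the source:

from tianshou.highlevel.config import SamplingConfig

# Hedged sketch: field names are assumptions about the high-level API.
sampling_config = SamplingConfig(
    num_epochs=10,
    step_per_epoch=30_000,
    batch_size=None,  # previously an int was required; None is now accepted
)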

15 changes: 7 additions & 8 deletions docs/02_notebooks/L0_overview.ipynb
@@ -17,14 +17,14 @@
},
{
"cell_type": "code",
"outputs": [],
"source": [
"# !pip install tianshou gym"
],
"execution_count": null,
"metadata": {
"collapsed": false
},
"execution_count": 0
"outputs": [],
"source": [
"# !pip install tianshou gym"
]
},
{
"cell_type": "markdown",
@@ -71,7 +71,7 @@
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import BasePolicy, PPOPolicy\n",
"from tianshou.policy import PPOPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
"from tianshou.utils.net.common import ActorCritic, Net\n",
"from tianshou.utils.net.discrete import Actor, Critic\n",
@@ -106,8 +106,7 @@
"\n",
"# PPO policy\n",
"dist = torch.distributions.Categorical\n",
"policy: BasePolicy\n",
"policy = PPOPolicy(\n",
"policy: PPOPolicy = PPOPolicy(\n",
" actor=actor,\n",
" critic=critic,\n",
" optim=optim,\n",
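
The hunks in this notebook (and the analogous ones in the notebooks below) apply one typing pattern: annotate the variable with the concrete policy class rather than `BasePolicy`, so that policy-specific attributes stay visible to mypy in later cells. A condensed sketch of the pattern, assuming `actor`, `critic`, `optim`, `dist`, and `env` are built as in the notebook; the constructor call is abbreviated, and the `dist_fn`/`action_space` keywords are assumptions about the full `PPOPolicy` signature rather than a copy of it:

from tianshou.policy import BasePolicy, PPOPolicy

# Before: the widened annotation hid PPOPolicy-specific members from mypy.
policy: BasePolicy
policy = PPOPolicy(actor=actor, critic=critic, optim=optim, dist_fn=dist, action_space=env.action_space)

# After: the concrete annotation keeps the full interface visible, while the
# object can still be passed wherever a BasePolicy is expected.
policy: PPOPolicy = PPOPolicy(actor=actor, critic=critic, optim=optim, dist_fn=dist, action_space=env.action_space)
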
5 changes: 2 additions & 3 deletions docs/02_notebooks/L5_Collector.ipynb
@@ -60,7 +60,7 @@
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import BasePolicy, PGPolicy\n",
"from tianshou.policy import PGPolicy\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor"
]
@@ -87,8 +87,7 @@
"actor = Actor(net, env.action_space.n)\n",
"optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n",
"\n",
"policy: BasePolicy\n",
"policy = PGPolicy(\n",
"policy: PGPolicy = PGPolicy(\n",
" actor=actor,\n",
" optim=optim,\n",
" dist_fn=torch.distributions.Categorical,\n",
5 changes: 2 additions & 3 deletions docs/02_notebooks/L6_Trainer.ipynb
@@ -75,7 +75,7 @@
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import BasePolicy, PGPolicy\n",
"from tianshou.policy import PGPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor"
@@ -110,9 +110,8 @@
"actor = Actor(net, env.action_space.n)\n",
"optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n",
"\n",
"policy: BasePolicy\n",
"# We choose to use REINFORCE algorithm, also known as Policy Gradient\n",
"policy = PGPolicy(\n",
"policy: PGPolicy = PGPolicy(\n",
" actor=actor,\n",
" optim=optim,\n",
" dist_fn=torch.distributions.Categorical,\n",
5 changes: 2 additions & 3 deletions docs/02_notebooks/L7_Experiment.ipynb
@@ -73,7 +73,7 @@
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import BasePolicy, PPOPolicy\n",
"from tianshou.policy import PPOPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
"from tianshou.utils.net.common import ActorCritic, Net\n",
"from tianshou.utils.net.discrete import Actor, Critic\n",
@@ -164,8 +164,7 @@
"outputs": [],
"source": [
"dist = torch.distributions.Categorical\n",
"policy: BasePolicy\n",
"policy = PPOPolicy(\n",
"policy: PPOPolicy = PPOPolicy(\n",
" actor=actor,\n",
" critic=critic,\n",
" optim=optim,\n",
6 changes: 3 additions & 3 deletions examples/atari/atari_c51.py
@@ -9,8 +9,8 @@
from atari_network import C51
from atari_wrapper import make_atari_env

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import C51Policy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
@@ -122,6 +122,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
log_path = os.path.join(args.logdir, log_name)

# logger
logger_factory = LoggerFactoryDefault()
if args.logger == "wandb":
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
@@ -182,8 +183,7 @@ def watch() -> None:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
result.pprint_asdict()

if args.watch:
watch()
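
Two changes recur across the atari example scripts and are both visible in this hunk: each script now builds its own `LoggerFactoryDefault` instead of importing a shared, mutable `logger_factory` from `examples.common`, and the hand-rolled mean-reward print is replaced by the collect result's `pprint_asdict()` helper. A condensed sketch stitched together from the hunks above, assuming `args` and `test_collector` are set up as in the script:

from tianshou.highlevel.logger import LoggerFactoryDefault

# Per-script factory instead of a shared module-level object.
logger_factory = LoggerFactoryDefault()
if args.logger == "wandb":
    logger_factory.logger_type = "wandb"
    logger_factory.wandb_project = args.wandb_project

# The collect result is a stats dataclass with a pretty-print mixin, so the
# manual mean-reward computation is no longer needed.
result = test_collector.collect(n_episode=args.test_num, render=args.render)
result.pprint_asdict()
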
9 changes: 5 additions & 4 deletions examples/atari/atari_dqn.py
@@ -9,8 +9,8 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import DQNPolicy
from tianshou.policy.base import BasePolicy
from tianshou.policy.modelbased.icm import ICMPolicy
@@ -104,7 +104,8 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
net = DQN(*args.state_shape, args.action_shape, args.device).to(args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy: DQNPolicy = DQNPolicy(
policy: DQNPolicy | ICMPolicy
policy = DQNPolicy(
model=net,
optim=optim,
action_space=env.action_space,
@@ -157,6 +158,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
log_path = os.path.join(args.logdir, log_name)

# logger
logger_factory = LoggerFactoryDefault()
if args.logger == "wandb":
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
@@ -223,8 +225,7 @@ def watch() -> None:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
result.pprint_asdict()

if args.watch:
watch()
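
The `policy: DQNPolicy | ICMPolicy` annotation above illustrates the pattern used whenever a script may rebind a variable to one of several policy types: declare the union once, then assign in whichever branch applies. A minimal sketch of the idea; the `DQNPolicy` call is abbreviated to the arguments visible in the hunk, and the ICM branch is only indicated:

from tianshou.policy import DQNPolicy
from tianshou.policy.modelbased.icm import ICMPolicy

policy: DQNPolicy | ICMPolicy
policy = DQNPolicy(
    model=net,
    optim=optim,
    action_space=env.action_space,
)
# Later the script may rebind `policy` to an ICMPolicy wrapping the DQN
# policy (when the intrinsic-curiosity options are enabled); the union
# annotation lets mypy accept both bindings.
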
6 changes: 3 additions & 3 deletions examples/atari/atari_fqf.py
@@ -9,8 +9,8 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import FQFPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
@@ -135,6 +135,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
log_path = os.path.join(args.logdir, log_name)

# logger
logger_factory = LoggerFactoryDefault()
if args.logger == "wandb":
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
@@ -195,8 +196,7 @@ def watch() -> None:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
result.pprint_asdict()

if args.watch:
watch()
6 changes: 3 additions & 3 deletions examples/atari/atari_iqn.py
@@ -9,8 +9,8 @@
from atari_network import DQN
from atari_wrapper import make_atari_env

from examples.common import logger_factory
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import IQNPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
@@ -132,6 +132,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
log_path = os.path.join(args.logdir, log_name)

# logger
logger_factory = LoggerFactoryDefault()
if args.logger == "wandb":
logger_factory.logger_type = "wandb"
logger_factory.wandb_project = args.wandb_project
@@ -192,8 +193,7 @@ def watch() -> None:
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
rew = result.returns_stat.mean
print(f"Mean reward (over {result['n/ep']} episodes): {rew}")
result.pprint_asdict()

if args.watch:
watch()