
Commit b375907

Daraan, gemini-code-assist[bot], and kamil-kaczmarek authored
[RLlib] Overhaul of the typing module & better device typing (#55291)
Resolves: #55288 (wrong `np.array` in `TensorType`)

Furthermore changes:

- Changed comments to (semi-)docstrings, which IDEs (e.g. VSCode + Pylance) display as tooltips, making that information available to the user.
- `AgentID: Any -> Hashable`, as it is used for dict keys.
- Changed `DeviceType` to no longer be a `TypeVar` (that makes no sense in the way it is currently used); it now also includes `DeviceLikeType` (`int | str | device`) from `torch`. IMO it could fully replace the current type, but to be defensive I only added it as an extra possible type.
- Used the updated `DeviceType` to improve the type of `Runner._device` and make it more correct.
- Used torch's own type in `data`; the current code supports more than just `str`. I refrained from adding a reference to `rllib`, although it would be nice if they were in sync.
- Some extra formatting that is forced by pre-commit.

---

> [!NOTE]
> Revamps `rllib.utils.typing` (NDArray-based `TensorType`, broader `DeviceType`, `AgentID` as `Hashable`, docstring cleanups) and updates call sites to use optional device typing and improved hints.
>
> - **Types**:
>   - Overhaul `rllib/utils/typing.py`:
>     - `TensorType` now uses `numpy.typing.NDArray`; heavy use of `TYPE_CHECKING` to avoid runtime deps on torch/tf/jax.
>     - `DeviceType` widened to `Union[str, torch.device, int]` (was a `TypeVar`).
>     - `AgentID` tightened to `Hashable`; `NetworkType` uses `keras.Model`.
>     - Refined aliases (e.g., `FromConfigSpec`, `SpaceStruct`) and added concise docstrings.
> - **Runners**:
>   - `Runner._device` is now `Optional` (`Union[DeviceType, None]`) with an updated docstring; same change in the offline runners' `_device` properties.
> - **Connectors**:
>   - `NumpyToTensor`: `device` param typed as `Optional[DeviceType]` (via `TYPE_CHECKING`).
> - **Utils**:
>   - `from_config`: typed `config: Optional[FromConfigSpec]` with a `TYPE_CHECKING` import.
> - **Misc**:
>   - Minor formatting/import ordering and comment typo fixes.
>
> <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit ae2e422.</sup>

---

Signed-off-by: Daniel Sperber <github.blurry@9ox.net>
Signed-off-by: Daraan <github.blurry@9ox.net>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Kamil Kaczmarek <kaczmarek.poczta@gmail.com>
Co-authored-by: Kamil Kaczmarek <kamil@anyscale.com>
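To make the typing changes concrete, below is a rough, hypothetical sketch of what the revised aliases could look like. It is not the actual content of `rllib/utils/typing.py` (presumably the seventh changed file, whose diff is not rendered below), and it omits the TensorFlow/JAX branches.

```python
# Hypothetical sketch only -- not the real rllib/utils/typing.py, which defines
# many more aliases and also covers TensorFlow and JAX tensors.
from typing import TYPE_CHECKING, Any, Hashable, Union

from numpy.typing import NDArray

if TYPE_CHECKING:
    # Heavy frameworks are imported for type checking only, never at runtime.
    import torch

# A tensor is either a plain NumPy array or a framework tensor.
TensorType = Union[NDArray[Any], "torch.Tensor"]

# Anything torch accepts as a device: a string ("cuda:0"), a torch.device
# object, or a bare device index.
DeviceType = Union[str, "torch.device", int]

# Agent IDs are used as dict keys, so they only need to be hashable.
AgentID = Hashable
```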
1 parent 27a6994 commit b375907

File tree

7 files changed: +204, -122 lines changed


rllib/connectors/common/numpy_to_tensor.py

Lines changed: 5 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import gymnasium as gym
 
@@ -12,6 +12,9 @@
 from ray.rllib.utils.typing import EpisodeType
 from ray.util.annotations import PublicAPI
 
+if TYPE_CHECKING:
+    from ray.rllib.utils.typing import DeviceType
+
 
 @PublicAPI(stability="alpha")
 class NumpyToTensor(ConnectorV2):
@@ -59,7 +62,7 @@ def __init__(
         input_action_space: Optional[gym.Space] = None,
         *,
         pin_memory: bool = False,
-        device: Optional[str] = None,
+        device: Optional["DeviceType"] = None,
         **kwargs,
     ):
         """Initializes a NumpyToTensor instance.

rllib/core/learner/differentiable_learner.py

Lines changed: 6 additions & 5 deletions
@@ -1,17 +1,18 @@
 import abc
 import logging
-import numpy
 from typing import (
+    TYPE_CHECKING,
     Any,
     Collection,
     Dict,
     Iterable,
     Optional,
     Tuple,
-    TYPE_CHECKING,
     Union,
 )
 
+import numpy
+
 from ray.rllib.connectors.learner.learner_connector_pipeline import (
     LearnerConnectorPipeline,
 )
@@ -22,19 +23,19 @@
 from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.utils import unflatten_dict
 from ray.rllib.utils.annotations import (
-    override,
     OverrideToImplementCustomLogic,
     OverrideToImplementCustomLogic_CallToSuperRecommended,
+    override,
 )
 from ray.rllib.utils.checkpoints import Checkpointable
 from ray.rllib.utils.metrics import (
     DATASET_NUM_ITERS_TRAINED,
     DATASET_NUM_ITERS_TRAINED_LIFETIME,
+    MODULE_TRAIN_BATCH_SIZE_MEAN,
     NUM_ENV_STEPS_TRAINED,
     NUM_ENV_STEPS_TRAINED_LIFETIME,
     NUM_MODULE_STEPS_TRAINED,
     NUM_MODULE_STEPS_TRAINED_LIFETIME,
-    MODULE_TRAIN_BATCH_SIZE_MEAN,
     WEIGHTS_SEQ_NO,
 )
 from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
@@ -124,7 +125,7 @@ def build(self, device: Optional[DeviceType] = None) -> None:
         if self._is_built:
             logger.debug("DifferentiableLearner already built. Skipping built.")
 
-        # If a dvice was passed, set the `DifferentiableLearner`'s device.
+        # If a device was passed, set the `DifferentiableLearner`'s device.
         if device:
             self._device = device
 
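Apart from import reordering, this hunk only fixes a comment typo, but the surrounding logic is the useful bit: `build()` overrides the stored device only when one is passed explicitly. A standalone sketch of that pattern with hypothetical names (not RLlib code):

```python
from typing import Optional, Union

# Hypothetical stand-in for rllib's DeviceType alias.
DeviceType = Union[str, int]


class TinyLearner:
    """Minimal illustration of the device-override pattern in build()."""

    def __init__(self, default_device: DeviceType = "cpu") -> None:
        self._device: DeviceType = default_device
        self._is_built = False

    def build(self, device: Optional[DeviceType] = None) -> None:
        # If a device was passed, set the learner's device; otherwise keep it.
        if device:
            self._device = device
        self._is_built = True


learner = TinyLearner()
learner.build(device="cuda:0")
assert learner._device == "cuda:0"
```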
rllib/offline/offline_evaluation_runner.py

Lines changed: 6 additions & 4 deletions
@@ -389,9 +389,11 @@ def set_device(self):
         try:
             self.__device = get_device(
                 self.config,
-                0
-                if not self.worker_index
-                else self.config.num_gpus_per_offline_eval_runner,
+                (
+                    0
+                    if not self.worker_index
+                    else self.config.num_gpus_per_offline_eval_runner
+                ),
             )
         except NotImplementedError:
             self.__device = None
@@ -456,7 +458,7 @@ def _batch_iterator(self) -> MiniBatchRayDataIterator:
         return self.__batch_iterator
 
     @property
-    def _device(self) -> DeviceType:
+    def _device(self) -> Union[DeviceType, None]:
         return self.__device
 
     @property
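Both hunks here are type/format-only. The first wraps the conditional expression passed to `get_device()` in parentheses, which changes layout but not behavior; the second widens the `_device` return type to allow `None`. A standalone sketch of the layout equivalence (hypothetical helper, not the runner's actual code):

```python
def pick_gpu_count(worker_index: int, num_gpus_per_runner: int) -> int:
    """Hypothetical helper mirroring the argument passed to get_device()."""
    # Old layout: the conditional expression spread directly across lines.
    old_style = 0 if not worker_index else num_gpus_per_runner
    # New layout: the same expression, just wrapped in parentheses.
    new_style = (
        0
        if not worker_index
        else num_gpus_per_runner
    )
    assert old_style == new_style
    return new_style


print(pick_gpu_count(worker_index=0, num_gpus_per_runner=1))  # 0 (index-0 runner)
print(pick_gpu_count(worker_index=2, num_gpus_per_runner=1))  # 1
```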

rllib/offline/offline_policy_evaluation_runner.py

Lines changed: 18 additions & 12 deletions
@@ -102,9 +102,11 @@ def __call__(self, batch: Dict[str, numpy.ndarray]) -> Dict[str, numpy.ndarray]:
         # TODO (simon): Refactor into a single code block for both cases.
         episodes = self.episode_buffer.sample(
             num_items=self.config.train_batch_size_per_learner,
-            batch_length_T=self.config.model_config.get("max_seq_len", 0)
-            if self._module.is_stateful()
-            else None,
+            batch_length_T=(
+                self.config.model_config.get("max_seq_len", 0)
+                if self._module.is_stateful()
+                else None
+            ),
             n_step=self.config.get("n_step", 1) or 1,
             # TODO (simon): This can be removed as soon as DreamerV3 has been
             # cleaned up, i.e. can use episode samples for training.
@@ -131,9 +133,11 @@ def __call__(self, batch: Dict[str, numpy.ndarray]) -> Dict[str, numpy.ndarray]:
         # Sample steps from the buffer.
         episodes = self.episode_buffer.sample(
             num_items=self.config.train_batch_size_per_learner,
-            batch_length_T=self.config.model_config.get("max_seq_len", 0)
-            if self._module.is_stateful()
-            else None,
+            batch_length_T=(
+                self.config.model_config.get("max_seq_len", 0)
+                if self._module.is_stateful()
+                else None
+            ),
             n_step=self.config.get("n_step", 1) or 1,
             # TODO (simon): This can be removed as soon as DreamerV3 has been
             # cleaned up, i.e. can use episode samples for training.
@@ -241,14 +245,14 @@ def _create_batch_iterator(self, **kwargs) -> Iterable:
         # Define the collate function that converts the flattened dictionary
         # to a `MultiAgentBatch` with Tensors.
         def _collate_fn(
-            _batch: Dict[str, numpy.ndarray]
+            _batch: Dict[str, numpy.ndarray],
         ) -> Dict[EpisodeID, Dict[str, numpy.ndarray]]:
 
             return _batch["episodes"]
 
         # Define the finalize function that makes the host-to-device transfer.
         def _finalize_fn(
-            _batch: Dict[EpisodeID, Dict[str, numpy.ndarray]]
+            _batch: Dict[EpisodeID, Dict[str, numpy.ndarray]],
         ) -> Dict[EpisodeID, Dict[str, TensorType]]:
 
             return [
@@ -556,9 +560,11 @@ def set_device(self):
         try:
             self.__device = get_device(
                 self.config,
-                0
-                if not self.worker_index
-                else self.config.num_gpus_per_offline_eval_runner,
+                (
+                    0
+                    if not self.worker_index
+                    else self.config.num_gpus_per_offline_eval_runner
+                ),
             )
         except NotImplementedError:
             self.__device = None
@@ -613,7 +619,7 @@ def _batch_iterator(self) -> MiniBatchRayDataIterator:
         return self.__batch_iterator
 
     @property
-    def _device(self) -> DeviceType:
+    def _device(self) -> Union[DeviceType, None]:
         return self.__device
 
     @property
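The remaining change in this file is the trailing comma added after the single parameter of `_collate_fn` and `_finalize_fn`, presumably enforced by the repo's pre-commit formatter; it has no runtime effect. A minimal standalone sketch with a hypothetical function name:

```python
from typing import Dict

import numpy


# With the trailing comma, formatters such as Black keep the exploded,
# one-parameter-per-line signature; without it they may collapse the line.
# Behavior is identical either way.
def _collate_example(
    _batch: Dict[str, numpy.ndarray],
) -> Dict[str, numpy.ndarray]:
    return _batch


print(_collate_example({"obs": numpy.zeros(3)}))
```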

rllib/utils/from_config.py

Lines changed: 6 additions & 1 deletion
@@ -10,9 +10,14 @@
 from ray.rllib.utils import force_list, merge_dicts
 from ray.rllib.utils.annotations import DeveloperAPI
 
+from typing import Optional, TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from ray.rllib.utils.typing import FromConfigSpec
+
 
 @DeveloperAPI
-def from_config(cls, config=None, **kwargs):
+def from_config(cls, config: Optional["FromConfigSpec"] = None, **kwargs):
     """Uses the given config to create an object.
 
     If `config` is a dict, an optional "type" key can be used as a
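The pattern used here, importing the alias only under `TYPE_CHECKING` and referencing it as a string in the annotation, gives type checkers the full `FromConfigSpec` type without any runtime import cost or circular-import risk. A generic, self-contained sketch of the same pattern with hypothetical module and names:

```python
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, Pylance, ...); never executed
    # at runtime, so there is no import cost and no circular-import risk.
    from some_heavy_module import HeavyConfigSpec  # hypothetical module/type


def build_thing(config: Optional["HeavyConfigSpec"] = None, **kwargs):
    # At runtime the annotation is just the string "HeavyConfigSpec", so the
    # function works even though some_heavy_module was never imported.
    config = dict(config or {}, **kwargs)
    return config


print(build_thing({"type": "MyClass"}, lr=0.001))
```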

rllib/utils/runners/runner.py

Lines changed: 4 additions & 3 deletions
@@ -1,6 +1,7 @@
 import abc
 import logging
-from typing import TYPE_CHECKING, Any
+
+from typing import TYPE_CHECKING, Any, Union
 
 from ray.rllib.utils.actor_manager import FaultAwareApply
 from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
@@ -86,8 +87,8 @@ def stop(self) -> None:
 
     @property
     @abc.abstractmethod
-    def _device(self) -> DeviceType:
-        """Returns the device of this `Runner`."""
+    def _device(self) -> Union[DeviceType, None]:
+        """Returns the device of this `Runner`. None if framework is not supported."""
         pass
 
     @abc.abstractmethod
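With `_device` now returning `Union[DeviceType, None]`, concrete runners return `None` when no framework-specific device can be resolved (as the offline runners above do after catching `NotImplementedError`). A hypothetical standalone sketch of that contract, not an actual RLlib runner:

```python
from typing import Optional, Union

# Hypothetical stand-in for ray.rllib.utils.typing.DeviceType.
DeviceType = Union[str, int]


class MyRunner:
    """Illustrative runner honoring the Optional-device contract."""

    def __init__(self) -> None:
        self.__device: Optional[DeviceType] = None
        try:
            self.__device = self._resolve_device()
        except NotImplementedError:
            # Framework without device support: leave the device as None.
            self.__device = None

    def _resolve_device(self) -> DeviceType:
        # Hypothetical resolution logic; a real runner would call get_device().
        return "cpu"

    @property
    def _device(self) -> Union[DeviceType, None]:
        """Returns the device of this runner. None if framework is not supported."""
        return self.__device


print(MyRunner()._device)  # "cpu"
```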
