diff --git a/.flake8 b/.flake8 index 21e86cbc5cdc..cb208f586035 100644 --- a/.flake8 +++ b/.flake8 @@ -44,6 +44,8 @@ ignore = B015 B016 B017 + B023 + B026 avoid-escape = no # Error E731 is ignored because of the migration from YAPF to Black. # See https://github.com/ray-project/ray/issues/21315 for more information. diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bb360d19ed16..5f8c13907fe2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -47,14 +47,14 @@ repos: types_or: [python] - repo: https://github.com/pycqa/flake8 - rev: 3.9.1 + rev: 7.1.1 hooks: - id: flake8 additional_dependencies: [ - flake8-comprehensions==3.10.1, - flake8-quotes==2.0.0, - flake8-bugbear==21.9.2, + flake8-comprehensions==3.15.0, + flake8-quotes==3.4.0, + flake8-bugbear==24.8.19, ] - repo: https://github.com/pre-commit/mirrors-prettier diff --git a/ci/lint/format.sh b/ci/lint/format.sh index bb3923918286..83b5178e209a 100755 --- a/ci/lint/format.sh +++ b/ci/lint/format.sh @@ -5,7 +5,7 @@ # Cause the script to exit if a single command fails set -euo pipefail -FLAKE8_VERSION_REQUIRED="3.9.1" +FLAKE8_VERSION_REQUIRED="7.1.1" BLACK_VERSION_REQUIRED="22.10.0" SHELLCHECK_VERSION_REQUIRED="0.7.1" MYPY_VERSION_REQUIRED="1.7.0" diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py index 09ec5c116261..238b1a507525 100644 --- a/doc/source/custom_directives.py +++ b/doc/source/custom_directives.py @@ -776,8 +776,8 @@ def render_library_examples(config: pathlib.Path = None) -> bs4.BeautifulSoup: soup.append(page_text) container = soup.new_tag("div", attrs={"class": "example-index"}) - for group, examples in examples.items(): - if not examples: + for group, _examples in examples.items(): + if not _examples: continue header = soup.new_tag("h2", attrs={"class": "example-header"}) @@ -810,7 +810,7 @@ def render_library_examples(config: pathlib.Path = None) -> bs4.BeautifulSoup: table.append(thead) tbody = soup.new_tag("tbody") - for example in examples: + for _example in _examples: tr = soup.new_tag("tr") # The columns specify which attributes of each example to show; @@ -821,7 +821,7 @@ def render_library_examples(config: pathlib.Path = None) -> bs4.BeautifulSoup: col_td = soup.new_tag("td") col_p = soup.new_tag("p") - attribute_value = getattr(example, attribute, "") + attribute_value = getattr(_example, attribute, "") if isinstance(attribute_value, str): col_p.append(attribute_value) elif isinstance(attribute_value, list): @@ -834,14 +834,14 @@ def render_library_examples(config: pathlib.Path = None) -> bs4.BeautifulSoup: link_td = soup.new_tag("td") link_p = soup.new_tag("p") - if example.link.startswith("http"): - link_href = soup.new_tag("a", attrs={"href": example.link}) + if _example.link.startswith("http"): + link_href = soup.new_tag("a", attrs={"href": _example.link}) else: link_href = soup.new_tag( - "a", attrs={"href": context["pathto"](example.link)} + "a", attrs={"href": context["pathto"](_example.link)} ) link_span = soup.new_tag("span") - link_span.append(example.title) + link_span.append(_example.title) link_href.append(link_span) link_p.append(link_href) link_td.append(link_p) diff --git a/doc/source/serve/doc_code/tutorial_sklearn.py b/doc/source/serve/doc_code/tutorial_sklearn.py index 923e50c591df..f0f8075267ca 100644 --- a/doc/source/serve/doc_code/tutorial_sklearn.py +++ b/doc/source/serve/doc_code/tutorial_sklearn.py @@ -28,7 +28,8 @@ iris_dataset["target_names"], ) -np.random.shuffle(data), np.random.shuffle(target) +np.random.shuffle(data) +np.random.shuffle(target) train_x, train_y = data[:100], target[:100] val_x, val_y = data[100:], target[100:] # __doc_data_end__ diff --git a/python/ray/_private/accelerators/accelerator.py b/python/ray/_private/accelerators/accelerator.py index 70178094e14c..60cf2ee1b40e 100644 --- a/python/ray/_private/accelerators/accelerator.py +++ b/python/ray/_private/accelerators/accelerator.py @@ -86,7 +86,9 @@ def validate_resource_request_quantity( @staticmethod @abstractmethod def get_current_process_visible_accelerator_ids() -> Optional[List[str]]: - """Get the ids of accelerators of this family that are visible to the current process. + """ + Get the ids of accelerators of this family + that are visible to the current process. Returns: The list of visiable accelerator ids. @@ -96,7 +98,9 @@ def get_current_process_visible_accelerator_ids() -> Optional[List[str]]: @staticmethod @abstractmethod def set_current_process_visible_accelerator_ids(ids: List[str]) -> None: - """Set the ids of accelerators of this family that are visible to the current process. + """ + Set the ids of accelerators of this family + that are visible to the current process. Args: ids: The ids of visible accelerators of this family. @@ -106,7 +110,8 @@ def set_current_process_visible_accelerator_ids(ids: List[str]) -> None: def get_ec2_instance_num_accelerators( instance_type: str, instances: dict ) -> Optional[int]: - """Get the number of accelerators of this family on ec2 instance with given type. + """ + Get the number of accelerators of this family on ec2 instance with given type. Args: instance_type: The ec2 instance type. diff --git a/python/ray/_private/node.py b/python/ray/_private/node.py index 156ecd0be606..290b14e5d829 100644 --- a/python/ray/_private/node.py +++ b/python/ray/_private/node.py @@ -534,7 +534,8 @@ def get_resource_spec(self): """Resolve and return the current resource spec for the node.""" def merge_resources(env_dict, params_dict): - """Separates special case params and merges two dictionaries, picking from the + """ + Separates special case params and merges two dictionaries, picking from the first in the event of a conflict. Also emit a warning on every conflict. """ diff --git a/python/ray/_private/runtime_env/plugin.py b/python/ray/_private/runtime_env/plugin.py index a1e03a507b59..1b55e38475f0 100644 --- a/python/ray/_private/runtime_env/plugin.py +++ b/python/ray/_private/runtime_env/plugin.py @@ -1,7 +1,7 @@ import logging import os import json -from abc import ABC +from abc import ABC, abstractmethod from typing import List, Dict, Optional, Any, Type from ray._private.runtime_env.context import RuntimeEnvContext @@ -28,6 +28,7 @@ class RuntimeEnvPlugin(ABC): priority: int = RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY @staticmethod + @abstractmethod def validate(runtime_env_dict: dict) -> None: """Validate user entry for this plugin. diff --git a/python/ray/_private/runtime_env/utils.py b/python/ray/_private/runtime_env/utils.py index a4387ff27b69..11de7717eef0 100644 --- a/python/ray/_private/runtime_env/utils.py +++ b/python/ray/_private/runtime_env/utils.py @@ -93,7 +93,7 @@ async def check_output_cmd( # since Python 3.9, when cancelled, the inner process needs to throw as it is # for asyncio to timeout properly https://bugs.python.org/issue40607 raise e - except BaseException as e: + except BaseException as e: # noqa: B036 To avoid breaking change raise RuntimeError(f"Run cmd[{cmd_index}] got exception.") from e else: stdout = stdout.decode("utf-8") diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py index 0de0ba78405f..7d94a008bee8 100644 --- a/python/ray/_private/services.py +++ b/python/ray/_private/services.py @@ -1608,7 +1608,7 @@ def start_raylet( Returns: ProcessInfo for the process that was started. """ - assert node_manager_port is not None and type(node_manager_port) == int + assert node_manager_port is not None and isinstance(node_manager_port, int) if use_valgrind and use_profiler: raise ValueError("Cannot use valgrind and profiler at the same time.") diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index 1eb26e0fad25..4fdf6be9ef42 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -790,7 +790,7 @@ def wait_until_succeeded_without_exception( Return: Whether exception occurs within a timeout. """ - if type(exceptions) != tuple: + if not isinstance(exceptions, tuple): raise Exception("exceptions arguments should be given as a tuple") time_elapsed = 0 diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 0abfb5757692..78fad3b5571a 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -1220,13 +1220,18 @@ def __getitem__(self, key): def __len__(self): if log_once("ray_context_len"): - warnings.warn("len(ctx) is deprecated. Use len(ctx.address_info) instead.") + warnings.warn( + "len(ctx) is deprecated. Use len(ctx.address_info) instead.", + DeprecationWarning, + stacklevel=2, + ) return len(self.address_info) def __iter__(self): if log_once("ray_context_len"): warnings.warn( - "iter(ctx) is deprecated. Use iter(ctx.address_info) instead." + "iter(ctx) is deprecated. Use iter(ctx.address_info) instead.", + stacklevel=2, ) return iter(self.address_info) @@ -2762,9 +2767,9 @@ def get( port=None, patch_stdstreams=False, quiet=None, - breakpoint_uuid=debugger_breakpoint.decode() - if debugger_breakpoint - else None, + breakpoint_uuid=( + debugger_breakpoint.decode() if debugger_breakpoint else None + ), debugger_external=worker.ray_debugger_external, ) rdb.set_trace(frame=frame) diff --git a/python/ray/air/_internal/device_manager/torch_device_manager.py b/python/ray/air/_internal/device_manager/torch_device_manager.py index d522a477ef58..d00bec53d2ca 100644 --- a/python/ray/air/_internal/device_manager/torch_device_manager.py +++ b/python/ray/air/_internal/device_manager/torch_device_manager.py @@ -1,4 +1,4 @@ -from abc import ABC +from abc import ABC, abstractmethod from typing import List, Union import torch @@ -9,32 +9,39 @@ class TorchDeviceManager(ABC): an acclerator family in Ray AI Library. """ + @abstractmethod def is_available(self) -> bool: """Validate if device is available.""" ... + @abstractmethod def get_devices(self) -> List[torch.device]: """Gets the correct torch device configured for this process""" ... + @abstractmethod def set_device(self, device: Union[torch.device, int, str, None]): """Set the correct device for this process""" ... + @abstractmethod def supports_stream(self) -> bool: """Validate if the device type support create a stream""" ... + @abstractmethod def create_stream(self, device: torch.device): """Create a device stream""" ... + @abstractmethod def get_stream_context(self, stream): """Get a stream context of device. If device didn't support stream, this should return a empty context manager instead of None. """ ... + @abstractmethod def get_current_stream(self): """Get current stream on accelerators like torch.cuda.current_stream""" ... diff --git a/python/ray/air/_internal/util.py b/python/ray/air/_internal/util.py index 2eb4463fa05c..94b9d8d7bf80 100644 --- a/python/ray/air/_internal/util.py +++ b/python/ray/air/_internal/util.py @@ -122,7 +122,7 @@ def run(self): else: # If non-zero exit code, then raise exception to main thread. self._propagate_exception(e) - except BaseException as e: + except BaseException as e: # noqa: B036 To avoid breaking change # Propagate all other exceptions to the main thread. self._propagate_exception(e) diff --git a/python/ray/air/execution/resources/resource_manager.py b/python/ray/air/execution/resources/resource_manager.py index daa2cdd69215..92a7b2bd7a01 100644 --- a/python/ray/air/execution/resources/resource_manager.py +++ b/python/ray/air/execution/resources/resource_manager.py @@ -62,6 +62,7 @@ class ResourceManager(abc.ABC): """ + @abc.abstractmethod def request_resources(self, resource_request: ResourceRequest): """Request resources. @@ -75,6 +76,7 @@ def request_resources(self, resource_request: ResourceRequest): """ raise NotImplementedError + @abc.abstractmethod def cancel_resource_request(self, resource_request: ResourceRequest): """Cancel resource request. @@ -84,10 +86,12 @@ def cancel_resource_request(self, resource_request: ResourceRequest): """ raise NotImplementedError + @abc.abstractmethod def has_resources_ready(self, resource_request: ResourceRequest) -> bool: """Returns True if resources for the given request are ready to be acquired.""" raise NotImplementedError + @abc.abstractmethod def acquire_resources( self, resource_request: ResourceRequest ) -> Optional[AcquiredResources]: @@ -98,6 +102,7 @@ def acquire_resources( """ raise NotImplementedError + @abc.abstractmethod def free_resources(self, acquired_resource: AcquiredResources): """Free acquired resources from usage and return them to the resource manager. @@ -108,6 +113,7 @@ def free_resources(self, acquired_resource: AcquiredResources): """ raise NotImplementedError + @abc.abstractmethod def get_resource_futures(self) -> List[ray.ObjectRef]: """Return futures for resources to await. @@ -120,6 +126,7 @@ def get_resource_futures(self) -> List[ray.ObjectRef]: """ return [] + @abc.abstractmethod def update_state(self): """Update internal state of the resource manager. @@ -132,6 +139,7 @@ def update_state(self): """ pass + @abc.abstractmethod def clear(self): """Reset internal state and clear all resources. diff --git a/python/ray/air/integrations/wandb.py b/python/ray/air/integrations/wandb.py index fcd683ec0e7f..53103448a194 100644 --- a/python/ray/air/integrations/wandb.py +++ b/python/ray/air/integrations/wandb.py @@ -558,6 +558,7 @@ def __init__( warnings.warn( "`save_checkpoints` is deprecated. Use `upload_checkpoints` instead.", DeprecationWarning, + stacklevel=2, ) upload_checkpoints = save_checkpoints diff --git a/python/ray/air/tests/test_integration_comet.py b/python/ray/air/tests/test_integration_comet.py index ad94cb33d0d9..b6692ee95beb 100644 --- a/python/ray/air/tests/test_integration_comet.py +++ b/python/ray/air/tests/test_integration_comet.py @@ -186,7 +186,7 @@ def test_kwargs_passthrough(self, experiment): logger.log_trial_start(trial) # These are the default kwargs that get passed to create the experiment - expected_kwargs = {kwarg: False for kwarg in logger._exclude_autolog} + expected_kwargs = dict.fromkeys(logger._exclude_autolog, False) expected_kwargs.update(experiment_kwargs) experiment.assert_called_with(**expected_kwargs) diff --git a/python/ray/air/tests/test_integration_wandb.py b/python/ray/air/tests/test_integration_wandb.py index abf1576407d7..42972e6cafd3 100644 --- a/python/ray/air/tests/test_integration_wandb.py +++ b/python/ray/air/tests/test_integration_wandb.py @@ -288,7 +288,7 @@ def test_wandb_logger_reporting(self, trial): def test_wandb_logger_auto_config_keys(self, trial): logger = WandbTestExperimentLogger(project="test_project", api_key="1234") logger.on_trial_start(iteration=0, trials=[], trial=trial) - result = {key: 0 for key in WandbLoggerCallback.AUTO_CONFIG_KEYS} + result = dict.fromkeys(WandbLoggerCallback.AUTO_CONFIG_KEYS, 0) logger.on_trial_result(0, [], trial, result) logger.on_trial_complete(0, [], trial) logger.on_experiment_end(trials=[trial]) @@ -314,7 +314,7 @@ def test_wandb_logger_exclude_config(self): logger.on_trial_start(iteration=0, trials=[], trial=trial) # We need to test that `excludes` also applies to `AUTO_CONFIG_KEYS`. - result = {key: 0 for key in WandbLoggerCallback.AUTO_CONFIG_KEYS} + result = dict.fromkeys(WandbLoggerCallback.AUTO_CONFIG_KEYS, 0) logger.on_trial_result(0, [], trial, result) logger.on_trial_complete(0, [], trial) logger.on_experiment_end(trials=[trial]) diff --git a/python/ray/air/util/data_batch_conversion.py b/python/ray/air/util/data_batch_conversion.py index e134b5b1d31f..5c76a58430a4 100644 --- a/python/ray/air/util/data_batch_conversion.py +++ b/python/ray/air/util/data_batch_conversion.py @@ -158,6 +158,7 @@ def convert_batch_type_to_pandas( "starting from Ray 2.4. All batch format conversions should be " "done manually instead of relying on this API.", PendingDeprecationWarning, + stacklevel=2, ) return _convert_batch_type_to_pandas( data=data, cast_tensor_columns=cast_tensor_columns @@ -186,6 +187,7 @@ def convert_pandas_to_batch_type( "starting from Ray 2.4. All batch format conversions should be " "done manually instead of relying on this API.", PendingDeprecationWarning, + stacklevel=2, ) return _convert_pandas_to_batch_type( data=data, type=type, cast_tensor_columns=cast_tensor_columns diff --git a/python/ray/air/util/torch_dist.py b/python/ray/air/util/torch_dist.py index 6a7316497710..318910e8a5e8 100644 --- a/python/ray/air/util/torch_dist.py +++ b/python/ray/air/util/torch_dist.py @@ -6,7 +6,7 @@ """ import os -from abc import ABC +from abc import ABC, abstractmethod from collections import defaultdict from datetime import timedelta from typing import Callable, List, T @@ -27,6 +27,7 @@ class TorchDistributedWorker(ABC): to be executed on a remote DDP worker. """ + @abstractmethod def execute(self, func: Callable[..., T], *args, **kwargs) -> T: """Executes the input function and returns the output. diff --git a/python/ray/autoscaler/_private/aws/config.py b/python/ray/autoscaler/_private/aws/config.py index a82deed3c1cd..93bb5165536d 100644 --- a/python/ray/autoscaler/_private/aws/config.py +++ b/python/ray/autoscaler/_private/aws/config.py @@ -688,7 +688,7 @@ def _check_ami(config): """Provide helpful message for missing ImageId for node configuration.""" # map from node type key -> source of ImageId field - ami_src_info = {key: "config" for key in config["available_node_types"]} + ami_src_info = dict.fromkeys(config["available_node_types"], "config") _set_config_info(ami_src=ami_src_info) region = config["provider"]["region"] diff --git a/python/ray/autoscaler/_private/cli_logger.py b/python/ray/autoscaler/_private/cli_logger.py index 01083be23eff..8b341aac896f 100644 --- a/python/ray/autoscaler/_private/cli_logger.py +++ b/python/ray/autoscaler/_private/cli_logger.py @@ -7,6 +7,7 @@ (depending on TTY features) as well as indentation and other structured output. """ + import inspect import logging import os diff --git a/python/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py b/python/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py index c685be58cf60..0a6c1cced89f 100644 --- a/python/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py +++ b/python/ray/autoscaler/v2/instance_manager/cloud_providers/kuberay/cloud_provider.py @@ -111,9 +111,7 @@ class ScaleRequest: def get_non_terminated(self) -> Dict[CloudInstanceId, CloudInstance]: self._sync_with_api_server() - return copy.deepcopy( - {id: instance for id, instance in self._cached_instances.items()} - ) + return copy.deepcopy(dict(self._cached_instances)) def terminate(self, ids: List[CloudInstanceId], request_id: str) -> None: if request_id in self._requests: diff --git a/python/ray/autoscaler/v2/tests/test_instance_util.py b/python/ray/autoscaler/v2/tests/test_instance_util.py index 23799b1ddc25..bb8002ea1077 100644 --- a/python/ray/autoscaler/v2/tests/test_instance_util.py +++ b/python/ray/autoscaler/v2/tests/test_instance_util.py @@ -282,7 +282,7 @@ def add_reachable_from(reachable, src, transitions): reachable[dst] if reachable[dst] is not None else set() ) - expected_reachable = {s: None for s in Instance.InstanceStatus.values()} + expected_reachable = dict.fromkeys(Instance.InstanceStatus.values(), None) # Error status and terminal status. expected_reachable[Instance.ALLOCATION_FAILED] = set() @@ -310,10 +310,10 @@ def add_reachable_from(reachable, src, transitions): # Add REQUESTED again since it's also reachable from QUEUED. add_reachable_from(expected_reachable, Instance.REQUESTED, transitions) - for s, expected_reachable in expected_reachable.items(): - assert InstanceUtil.get_reachable_statuses(s) == expected_reachable, ( + for s, _expected_reachable in expected_reachable.items(): + assert InstanceUtil.get_reachable_statuses(s) == _expected_reachable, ( f"reachable_from({s}) = {InstanceUtil.get_reachable_statuses(s)} " - f"!= {expected_reachable}" + f"!= {_expected_reachable}" ) diff --git a/python/ray/dag/compiled_dag_node.py b/python/ray/dag/compiled_dag_node.py index 0e56c1f4ea52..b58bee33333b 100644 --- a/python/ray/dag/compiled_dag_node.py +++ b/python/ray/dag/compiled_dag_node.py @@ -852,7 +852,7 @@ def _preprocess(self) -> None: "the driver cannot participate in the NCCL group" ) - if type(dag_node.type_hint) == ChannelOutputType: + if isinstance(dag_node.type_hint, ChannelOutputType): # No type hint specified by the user. Replace # with the default type hint for this DAG. dag_node.with_type_hint(self._default_type_hint) diff --git a/python/ray/dag/tests/test_py_obj_scanner.py b/python/ray/dag/tests/test_py_obj_scanner.py index c07fdd499e38..989cf3eb1af6 100644 --- a/python/ray/dag/tests/test_py_obj_scanner.py +++ b/python/ray/dag/tests/test_py_obj_scanner.py @@ -15,7 +15,7 @@ def test_simple_replace(): found = scanner.find_nodes(my_objs) assert len(found) == 3 - replaced = scanner.replace_nodes({obj: 1 for obj in found}) + replaced = scanner.replace_nodes(dict.fromkeys(found, 1)) assert replaced == [1, [1, {"key": 1}]] @@ -51,7 +51,7 @@ def __eq__(self, other): found = scanner.find_nodes(my_objs) assert len(found) == 3 - replaced = scanner.replace_nodes({obj: 1 for obj in found}) + replaced = scanner.replace_nodes(dict.fromkeys(found, 1)) assert replaced == [Outer(1), Outer(Outer(1)), Outer((1,))] @@ -73,7 +73,7 @@ def call_find_and_replace_nodes(): scanner = _PyObjScanner(source_type=Source) my_objs = [Source(), [Source(), {"key": Source()}]] found = scanner.find_nodes(my_objs) - scanner.replace_nodes({obj: 1 for obj in found}) + scanner.replace_nodes(dict.fromkeys(found, 1)) scanner.clear() assert id(scanner) not in _instances diff --git a/python/ray/dashboard/head.py b/python/ray/dashboard/head.py index fe6e6cb4365f..149d86791f55 100644 --- a/python/ray/dashboard/head.py +++ b/python/ray/dashboard/head.py @@ -313,12 +313,13 @@ async def _async_notify(): # This could be done better in the future, including # removing the polling on the Ray side, by communicating the # server address to Ray via stdin / stdout or a pipe. + self.gcs_client.internal_kv_put( ray_constants.DASHBOARD_ADDRESS.encode(), f"{dashboard_http_host}:{http_port}".encode(), True, namespace=ray_constants.KV_NAMESPACE_DASHBOARD, - ), + ) self.gcs_client.internal_kv_put( dashboard_consts.DASHBOARD_RPC_ADDRESS.encode(), f"{self.ip}:{self.grpc_port}".encode(), diff --git a/python/ray/dashboard/modules/actor/actor_head.py b/python/ray/dashboard/modules/actor/actor_head.py index 990579a9c605..12af26f609af 100644 --- a/python/ray/dashboard/modules/actor/actor_head.py +++ b/python/ray/dashboard/modules/actor/actor_head.py @@ -144,9 +144,11 @@ async def __call__( reply = await self._gcs_actor_info_stub.GetAllActorInfo( request, timeout=timeout, - filters=gcs_service_pb2.GetAllActorInfoRequest.Filters(state=state) - if state - else None, + filters=( + gcs_service_pb2.GetAllActorInfoRequest.Filters(state=state) + if state + else None + ), ) if reply.status.code != 0: @@ -267,7 +269,9 @@ def _convert_to_dict(): def _process_updated_actor_table( self, actor_id: str, actor_table_data: Dict[str, Any] ): - """NOTE: This method has to be executed on the event-loop, provided that it accesses + """ + NOTE: This method has to be executed on the event-loop, + provided that it accesses DataSource data structures (to follow its thread-safety model)""" # If actor is not new registered but updated, we only update diff --git a/python/ray/dashboard/modules/node/node_head.py b/python/ray/dashboard/modules/node/node_head.py index e146d9409cb6..9478e95810e0 100644 --- a/python/ray/dashboard/modules/node/node_head.py +++ b/python/ray/dashboard/modules/node/node_head.py @@ -106,6 +106,7 @@ def create(cls, *args, **kwargs): else: return GetAllNodeInfoFromNewGcsClient(*args, **kwargs) + @abc.abstractmethod async def __call__( self, *, diff --git a/python/ray/dashboard/modules/reporter/tests/test_reporter.py b/python/ray/dashboard/modules/reporter/tests/test_reporter.py index 839684d9d49d..6ed34195e4be 100644 --- a/python/ray/dashboard/modules/reporter/tests/test_reporter.py +++ b/python/ray/dashboard/modules/reporter/tests/test_reporter.py @@ -720,7 +720,7 @@ def _generate_worker_key(self, proc): children_pids = {p.pid for p in children} workers = ReporterAgent._get_workers(obj) # In the first run, the percent should be 0. - assert all([worker["cpu_percent"] == 0.0 for worker in workers]) + assert all(worker["cpu_percent"] == 0.0 for worker in workers) for _ in range(10): time.sleep(0.1) workers = ReporterAgent._get_workers(obj) diff --git a/python/ray/dashboard/tests/test_dashboard.py b/python/ray/dashboard/tests/test_dashboard.py index 4b61accc53c3..6b283c8065a1 100644 --- a/python/ray/dashboard/tests/test_dashboard.py +++ b/python/ray/dashboard/tests/test_dashboard.py @@ -783,7 +783,7 @@ def test_immutable_types(): d["list"][0] = {str(i): i for i in range(1000)} d["dict"] = {str(i): i for i in range(1000)} immutable_dict = dashboard_utils.make_immutable(d) - assert type(immutable_dict) == dashboard_utils.ImmutableDict + assert isinstance(immutable_dict, dashboard_utils.ImmutableDict) assert immutable_dict == dashboard_utils.ImmutableDict(d) assert immutable_dict == d assert dashboard_utils.ImmutableDict(immutable_dict) == immutable_dict @@ -801,8 +801,8 @@ def test_immutable_types(): # Test json dumps / loads json_str = json.dumps(immutable_dict, cls=dashboard_optional_utils.CustomEncoder) deserialized_immutable_dict = json.loads(json_str) - assert type(deserialized_immutable_dict) == dict - assert type(deserialized_immutable_dict["list"]) == list + assert isinstance(deserialized_immutable_dict, dict) + assert isinstance(deserialized_immutable_dict["list"], list) assert immutable_dict.mutable() == deserialized_immutable_dict dashboard_optional_utils.rest_response(True, "OK", data=immutable_dict) dashboard_optional_utils.rest_response(True, "OK", **immutable_dict) @@ -815,12 +815,12 @@ def test_immutable_types(): # Test get default immutable immutable_default_value = immutable_dict.get("not exist list", [1, 2]) - assert type(immutable_default_value) == dashboard_utils.ImmutableList + assert isinstance(immutable_default_value, dashboard_utils.ImmutableList) # Test recursive immutable - assert type(immutable_dict["list"]) == dashboard_utils.ImmutableList - assert type(immutable_dict["dict"]) == dashboard_utils.ImmutableDict - assert type(immutable_dict["list"][0]) == dashboard_utils.ImmutableDict + assert isinstance(immutable_dict["list"], dashboard_utils.ImmutableList) + assert isinstance(immutable_dict["dict"], dashboard_utils.ImmutableDict) + assert isinstance(immutable_dict["list"][0], dashboard_utils.ImmutableDict) # Test exception with pytest.raises(TypeError): diff --git a/python/ray/data/_internal/block_batching/interfaces.py b/python/ray/data/_internal/block_batching/interfaces.py index 452b6d850b93..41c4f54475e4 100644 --- a/python/ray/data/_internal/block_batching/interfaces.py +++ b/python/ray/data/_internal/block_batching/interfaces.py @@ -42,6 +42,7 @@ def prefetch_blocks(self, blocks: List[ObjectRef[Block]]): """Prefetch the provided blocks to this node.""" pass + @abc.abstractmethod def stop(self): """Stop prefetching and release resources.""" pass diff --git a/python/ray/data/_internal/datasource/range_datasource.py b/python/ray/data/_internal/datasource/range_datasource.py index 50dedbdf2fed..c6fcddf7a776 100644 --- a/python/ray/data/_internal/datasource/range_datasource.py +++ b/python/ray/data/_internal/datasource/range_datasource.py @@ -112,7 +112,7 @@ def make_blocks( return read_tasks - @functools.cache + @functools.cache # noqa: B019 - Ignore to prevent GC issues def _schema(self): if self._n == 0: return None diff --git a/python/ray/data/_internal/datasource/sql_datasource.py b/python/ray/data/_internal/datasource/sql_datasource.py index 8a6b5713fbbe..c117816b8130 100644 --- a/python/ray/data/_internal/datasource/sql_datasource.py +++ b/python/ray/data/_internal/datasource/sql_datasource.py @@ -81,7 +81,7 @@ def __init__(self, sql: str, connection_factory: Callable[[], Connection]): self.connection_factory = connection_factory def estimate_inmemory_data_size(self) -> Optional[int]: - None + pass def get_read_tasks(self, parallelism: int) -> List[ReadTask]: def fallback_read_fn() -> Iterable[Block]: diff --git a/python/ray/data/_internal/datasource/tfrecords_datasource.py b/python/ray/data/_internal/datasource/tfrecords_datasource.py index e5aa8b478626..5e3a20cf19ae 100644 --- a/python/ray/data/_internal/datasource/tfrecords_datasource.py +++ b/python/ray/data/_internal/datasource/tfrecords_datasource.py @@ -355,7 +355,7 @@ def _cast_large_list_to_list(batch: pyarrow.Table): for column_name in old_schema.names: field_type = old_schema.field(column_name).type - if type(field_type) == pyarrow.lib.LargeListType: + if isinstance(field_type, pyarrow.lib.LargeListType): value_type = field_type.value_type if value_type == pyarrow.large_binary(): @@ -416,7 +416,7 @@ def __init__(self, columns: List[str]): ) def _init(self, k: str): - return {col: 0 for col in self._columns} + return dict.fromkeys(self._columns, 0) def _merge(self, acc1: Dict[str, int], acc2: Dict[str, int]): merged = {} diff --git a/python/ray/data/_internal/execution/operators/map_operator.py b/python/ray/data/_internal/execution/operators/map_operator.py index a3d98bbeee3b..84d81ef0302f 100644 --- a/python/ray/data/_internal/execution/operators/map_operator.py +++ b/python/ray/data/_internal/execution/operators/map_operator.py @@ -567,6 +567,7 @@ def notify_task_output_ready(self, task_index: int, output: RefBundle): """Called when a task's output is ready.""" pass + @abstractmethod def notify_task_completed(self, task_index: int): """Called when a previously pending task completes.""" pass diff --git a/python/ray/data/_internal/execution/resource_manager.py b/python/ray/data/_internal/execution/resource_manager.py index 60cf3c4e7184..8edb90d50884 100644 --- a/python/ray/data/_internal/execution/resource_manager.py +++ b/python/ray/data/_internal/execution/resource_manager.py @@ -556,8 +556,9 @@ def _get_downstream_ineligible_ops( def _get_downstream_eligible_ops( self, op: PhysicalOperator ) -> Iterable[PhysicalOperator]: - """Get the downstream eligible operators of the given operator, ignoring intermediate - ineligible operators. + """ + Get the downstream eligible operators of the given operator, + ignoring intermediate ineligible operators. E.g., - "cur_map->downstream_map" will return [downstream_map]. diff --git a/python/ray/data/_internal/execution/streaming_executor.py b/python/ray/data/_internal/execution/streaming_executor.py index 238f6f9421cc..1b06dfb8c989 100644 --- a/python/ray/data/_internal/execution/streaming_executor.py +++ b/python/ray/data/_internal/execution/streaming_executor.py @@ -134,7 +134,7 @@ def execute( self._execution_id, ) - self._has_op_completed = {op: False for op in self._topology} + self._has_op_completed = dict.fromkeys(self._topology, False) self._output_node: OpState = self._topology[dag] StatsManager.register_dataset_to_stats_actor( diff --git a/python/ray/data/_internal/logical/operators/from_operators.py b/python/ray/data/_internal/logical/operators/from_operators.py index 7d2b07979c22..71eac7b145a6 100644 --- a/python/ray/data/_internal/logical/operators/from_operators.py +++ b/python/ray/data/_internal/logical/operators/from_operators.py @@ -40,7 +40,7 @@ def input_data(self) -> List[RefBundle]: def output_data(self) -> Optional[List[RefBundle]]: return self._input_data - @functools.cache + @functools.cache # noqa: B019 - Ignore to prevent GC issues def aggregate_output_metadata(self) -> BlockMetadata: return BlockMetadata( num_rows=self._num_rows(), diff --git a/python/ray/data/_internal/logical/operators/input_data_operator.py b/python/ray/data/_internal/logical/operators/input_data_operator.py index 6592972e7c94..d84bed1b2fe8 100644 --- a/python/ray/data/_internal/logical/operators/input_data_operator.py +++ b/python/ray/data/_internal/logical/operators/input_data_operator.py @@ -33,7 +33,7 @@ def output_data(self) -> Optional[List[RefBundle]]: return None return self.input_data - @functools.cache + @functools.cache # noqa: B019 - Ignore to prevent GC issues def aggregate_output_metadata(self) -> BlockMetadata: if self.input_data is None: return BlockMetadata(None, None, None, None, None) diff --git a/python/ray/data/_internal/logical/operators/read_operator.py b/python/ray/data/_internal/logical/operators/read_operator.py index b75a314ea6cb..7aa6ea055413 100644 --- a/python/ray/data/_internal/logical/operators/read_operator.py +++ b/python/ray/data/_internal/logical/operators/read_operator.py @@ -46,7 +46,7 @@ def get_detected_parallelism(self) -> int: """ return self._detected_parallelism - @functools.cache + @functools.cache # noqa: B019 - Ignore to prevent GC issues def aggregate_output_metadata(self) -> BlockMetadata: """A ``BlockMetadata`` that represents the aggregate metadata of the outputs. diff --git a/python/ray/data/_internal/planner/plan_read_op.py b/python/ray/data/_internal/planner/plan_read_op.py index 94dbae1f3871..da4e7afccb1f 100644 --- a/python/ray/data/_internal/planner/plan_read_op.py +++ b/python/ray/data/_internal/planner/plan_read_op.py @@ -42,7 +42,8 @@ def cleaned_metadata(read_task: ReadTask, read_task_ref) -> BlockMetadata: f"'{read_task.read_fn.__name__}' is {memory_string(task_size)}. This size " "relatively large. As a result, Ray might excessively " "spill objects during execution. To fix this issue, avoid accessing " - f"`self` or other large objects in '{read_task.read_fn.__name__}'." + f"`self` or other large objects in '{read_task.read_fn.__name__}'.", + stacklevel=2, ) # Defensively compute the size of the block as the max size reported by the diff --git a/python/ray/data/_internal/planner/random_shuffle.py b/python/ray/data/_internal/planner/random_shuffle.py index 88e5b255cd0e..fc5ca19e86ef 100644 --- a/python/ray/data/_internal/planner/random_shuffle.py +++ b/python/ray/data/_internal/planner/random_shuffle.py @@ -35,7 +35,6 @@ def fn( # MapOperator->AllToAllOperator), we pass a map function which # is applied to each block before shuffling. map_transformer: Optional[MapTransformer] = ctx.upstream_map_transformer - upstream_map_fn = None nonlocal ray_remote_args if map_transformer: # NOTE(swang): We override the target block size with infinity, to @@ -53,7 +52,8 @@ def upstream_map_fn(blocks): # If there is a fused upstream operator, # also use the ray_remote_args from the fused upstream operator. ray_remote_args = ctx.upstream_map_ray_remote_args - + else: + upstream_map_fn = None shuffle_spec = ShuffleTaskSpec( ctx.target_max_block_size, random_shuffle=True, diff --git a/python/ray/data/_internal/planner/repartition.py b/python/ray/data/_internal/planner/repartition.py index 73059f703535..5047692adbc3 100644 --- a/python/ray/data/_internal/planner/repartition.py +++ b/python/ray/data/_internal/planner/repartition.py @@ -35,7 +35,6 @@ def shuffle_repartition_fn( # MapOperator->AllToAllOperator), we pass a map function which # is applied to each block before shuffling. map_transformer: Optional["MapTransformer"] = ctx.upstream_map_transformer - upstream_map_fn = None if map_transformer: # NOTE(swang): We override the target block size with infinity, to # prevent the upstream map from slicing its output into smaller @@ -49,6 +48,9 @@ def shuffle_repartition_fn( def upstream_map_fn(blocks): return map_transformer.apply_transform(blocks, ctx) + else: + upstream_map_fn = None + shuffle_spec = ShuffleTaskSpec( ctx.target_max_block_size, random_shuffle=False, diff --git a/python/ray/data/_internal/split.py b/python/ray/data/_internal/split.py index 3f7fe145af09..f78654ebc65c 100644 --- a/python/ray/data/_internal/split.py +++ b/python/ray/data/_internal/split.py @@ -48,7 +48,9 @@ def _generate_per_block_split_indices( num_rows_per_block: List[int], split_indices: List[int], ) -> List[List[int]]: - """Given num rows per block and valid split indices, generate per block split indices. + """ + Given num rows per block and valid split indices, generate + per block split indices. Args: num_rows_per_block: num of rows per block. diff --git a/python/ray/data/_internal/stats.py b/python/ray/data/_internal/stats.py index a54810f9ab16..9f1f0d91c9fd 100644 --- a/python/ray/data/_internal/stats.py +++ b/python/ray/data/_internal/stats.py @@ -1239,7 +1239,8 @@ def from_block_metadata( ) def __str__(self) -> str: - """For a given (pre-calculated) `OperatorStatsSummary` object (e.g. generated from + """ + For a given (pre-calculated) `OperatorStatsSummary` object (e.g. generated from `OperatorStatsSummary.from_block_metadata()`), returns a human-friendly string that summarizes operator execution statistics. @@ -1361,7 +1362,8 @@ def __str__(self) -> str: return out def __repr__(self, level=0) -> str: - """For a given (pre-calculated) `OperatorStatsSummary` object (e.g. generated from + """ + For a given (pre-calculated) `OperatorStatsSummary` object (e.g. generated from `OperatorStatsSummary.from_block_metadata()`), returns a human-friendly string that summarizes operator execution statistics. diff --git a/python/ray/data/_internal/util.py b/python/ray/data/_internal/util.py index 0a9d975b00c0..81a110c3825f 100644 --- a/python/ray/data/_internal/util.py +++ b/python/ray/data/_internal/util.py @@ -972,9 +972,7 @@ def call_with_retry( try: return f() except Exception as e: - is_retryable = match is None or any( - [pattern in str(e) for pattern in match] - ) + is_retryable = match is None or any(pattern in str(e) for pattern in match) if is_retryable and i + 1 < max_attempts: # Retry with binary expoential backoff with random jitter. backoff = min((2 ** (i + 1)), max_backoff_s) * random.random() @@ -1023,9 +1021,7 @@ def iterate_with_retry( yield item return except Exception as e: - is_retryable = match is None or any( - [pattern in str(e) for pattern in match] - ) + is_retryable = match is None or any(pattern in str(e) for pattern in match) if is_retryable and i + 1 < max_attempts: # Retry with binary expoential backoff with random jitter. backoff = min((2 ** (i + 1)), max_backoff_s) * random.random() diff --git a/python/ray/data/context.py b/python/ray/data/context.py index f43fc2a50246..2714eebe2f61 100644 --- a/python/ray/data/context.py +++ b/python/ray/data/context.py @@ -376,6 +376,7 @@ def __setattr__(self, name: str, value: Any) -> None: "`write_file_retry_on_errors` is deprecated. Configure " "`retried_io_errors` instead.", DeprecationWarning, + stacklevel=2, ) super().__setattr__(name, value) diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index 2316afe958b8..f4b97addaca7 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -3643,7 +3643,8 @@ def write_bigquery( if ray_remote_args.get("max_retries", 0) != 0: warnings.warn( "The max_retries of a BigQuery Write Task should be set to 0" - " to avoid duplicate writes." + " to avoid duplicate writes.", + stacklevel=2, ) else: ray_remote_args["max_retries"] = 0 diff --git a/python/ray/data/datasource/datasink.py b/python/ray/data/datasource/datasink.py index 0832e0539fd1..723aa05d9d45 100644 --- a/python/ray/data/datasource/datasink.py +++ b/python/ray/data/datasource/datasink.py @@ -136,7 +136,8 @@ def supports_distributed_writes(self) -> bool: @property def num_rows_per_write(self) -> Optional[int]: - """The target number of rows to pass to each :meth:`~ray.data.Datasink.write` call. + """ + The target number of rows to pass to each :meth:`~ray.data.Datasink.write` call. If ``None``, Ray Data passes a system-chosen number of rows. """ diff --git a/python/ray/data/datasource/file_based_datasource.py b/python/ray/data/datasource/file_based_datasource.py index fc92ca458067..df19fee459f4 100644 --- a/python/ray/data/datasource/file_based_datasource.py +++ b/python/ray/data/datasource/file_based_datasource.py @@ -259,7 +259,7 @@ def read_task_fn(): parallelism = min(parallelism, len(paths)) read_tasks = [] - for read_paths, file_sizes in zip( + for read_paths, _file_sizes in zip( np.array_split(paths, parallelism), np.array_split(file_sizes, parallelism) ): if len(read_paths) <= 0: @@ -269,7 +269,7 @@ def read_task_fn(): read_paths, self._schema, rows_per_file=self._rows_per_file(), - file_sizes=file_sizes, + file_sizes=_file_sizes, ) read_task_fn = create_read_task_fn(read_paths, self._NUM_THREADS_PER_TASK) diff --git a/python/ray/data/datasource/file_meta_provider.py b/python/ray/data/datasource/file_meta_provider.py index c6654e9e2708..b80fda7e367c 100644 --- a/python/ray/data/datasource/file_meta_provider.py +++ b/python/ray/data/datasource/file_meta_provider.py @@ -315,7 +315,7 @@ def _get_file_infos_common_path_prefix( filesystem: "pyarrow.fs.FileSystem", ignore_missing_paths: bool = False, ) -> Iterator[Tuple[str, int]]: - path_to_size = {path: None for path in paths} + path_to_size = dict.fromkeys(paths, None) for path, file_size in _get_file_infos( common_path, filesystem, ignore_missing_paths ): diff --git a/python/ray/data/iterator.py b/python/ray/data/iterator.py index 498950806a76..1c6c34f9ac50 100644 --- a/python/ray/data/iterator.py +++ b/python/ray/data/iterator.py @@ -225,11 +225,14 @@ def iter_rows( "the `prefetch_batches` parameter to specify the amount of prefetching " "in terms of batches instead of blocks.", DeprecationWarning, + stacklevel=2, ) iter_batch_args["prefetch_batches"] = prefetch_blocks if prefetch_batches != 1: warnings.warn( - "`prefetch_batches` is deprecated in Ray 2.12.", DeprecationWarning + "`prefetch_batches` is deprecated in Ray 2.12.", + DeprecationWarning, + stacklevel=2, ) batch_iterable = self.iter_batches(**iter_batch_args) diff --git a/python/ray/data/preprocessor.py b/python/ray/data/preprocessor.py index 9db73405a702..0b1b66978237 100644 --- a/python/ray/data/preprocessor.py +++ b/python/ray/data/preprocessor.py @@ -111,7 +111,8 @@ def fit(self, ds: "Dataset") -> "Preprocessor": warnings.warn( "`fit` has already been called on the preprocessor (or at least one " "contained preprocessors if this is a chain). " - "All previously fitted state will be overwritten!" + "All previously fitted state will be overwritten!", + stacklevel=2, ) fitted_ds = self._fit(ds) diff --git a/python/ray/data/read_api.py b/python/ray/data/read_api.py index b86920f28069..2b17352f2341 100644 --- a/python/ray/data/read_api.py +++ b/python/ray/data/read_api.py @@ -3170,6 +3170,7 @@ def _get_datasource_or_legacy_reader( "`Reader`, implement `Datasource.get_read_tasks` and " "`Datasource.estimate_inmemory_data_size`.", DeprecationWarning, + stacklevel=2, ) datasource_or_legacy_reader = ds.create_reader(**kwargs) else: diff --git a/python/ray/data/tests/test_image.py b/python/ray/data/tests/test_image.py index ade22cdec35f..7d96159e52f5 100644 --- a/python/ray/data/tests/test_image.py +++ b/python/ray/data/tests/test_image.py @@ -126,7 +126,7 @@ def test_mode( ): # "different-modes" contains 32x32 images with modes "CMYK", "L", and "RGB" ds = ray.data.read_images("example://image-datasets/different-modes", mode=mode) - assert all([record["image"].shape == expected_shape for record in ds.take()]) + assert all(record["image"].shape == expected_shape for record in ds.take()) def test_partitioning( self, ray_start_regular_shared, enable_automatic_tensor_extension_cast @@ -178,7 +178,7 @@ def test_random_shuffle(self, ray_start_regular_shared, restore_data_context): assert not all(all_paths_matched) # Check all files are output properly without missing one. assert all( - [file_paths == sorted(output_paths) for output_paths in output_paths_list] + file_paths == sorted(output_paths) for output_paths in output_paths_list ) def test_e2e_prediction(self, shutdown_only): diff --git a/python/ray/data/tests/test_map.py b/python/ray/data/tests/test_map.py index 4f058a9152f5..8b2d7e011742 100644 --- a/python/ray/data/tests/test_map.py +++ b/python/ray/data/tests/test_map.py @@ -82,7 +82,7 @@ def test_basic_actors(shutdown_only): def _all_actors_dead(): actor_table = ray.state.actors() actors = { - id: actor_info + actor_info["ActorClassName"]: actor_info for actor_info in actor_table.values() if actor_info["ActorClassName"] == _MapWorker.__name__ } diff --git a/python/ray/data/tests/test_numpy_support.py b/python/ray/data/tests/test_numpy_support.py index c14038918c0a..4f1f09aae97c 100644 --- a/python/ray/data/tests/test_numpy_support.py +++ b/python/ray/data/tests/test_numpy_support.py @@ -22,7 +22,7 @@ def do_map_batches(data): def assert_structure_equals(a, b): - assert type(a) == type(b), (type(a), type(b)) + assert isinstance(a, b), (type(a), type(b)) assert type(a[0]) == type(b[0]), (type(a[0]), type(b[0])) # noqa: E721 assert a.dtype == b.dtype assert a.shape == b.shape diff --git a/python/ray/data/tests/test_operators.py b/python/ray/data/tests/test_operators.py index 153fc38fb7eb..2e3628ce90d7 100644 --- a/python/ray/data/tests/test_operators.py +++ b/python/ray/data/tests/test_operators.py @@ -743,7 +743,7 @@ def test_union_operator(ray_start_regular_shared, preserve_order): assert union_op.get_next() == data2[0] assert union_op.get_next() == data1[1] - assert all([len(b) == 0 for b in union_op._input_buffers]) + assert all(len(b) == 0 for b in union_op._input_buffers) _take_outputs(union_op) union_op.all_inputs_done() diff --git a/python/ray/data/tests/test_resource_manager.py b/python/ray/data/tests/test_resource_manager.py index 376f947aa61b..a26fedcfb8bc 100644 --- a/python/ray/data/tests/test_resource_manager.py +++ b/python/ray/data/tests/test_resource_manager.py @@ -325,9 +325,9 @@ def test_basic(self, restore_data_context): o3 = mock_map_op(o2, incremental_resource_usage=ExecutionResources(1, 0, 10)) o4 = LimitOperator(1, o3) - op_usages = {op: ExecutionResources.zero() for op in [o1, o2, o3, o4]} - op_internal_usage = {op: 0 for op in [o1, o2, o3, o4]} - op_outputs_usages = {op: 0 for op in [o1, o2, o3, o4]} + op_usages = dict.fromkeys([o1, o2, o3, o4], ExecutionResources.zero()) + op_internal_usage = dict.fromkeys([o1, o2, o3, o4], 0) + op_outputs_usages = dict.fromkeys([o1, o2, o3, o4], 0) topo, _ = build_streaming_topology(o4, ExecutionOptions()) diff --git a/python/ray/data/tests/test_split.py b/python/ray/data/tests/test_split.py index 1af6596fdb13..b07d1c2224ee 100644 --- a/python/ray/data/tests/test_split.py +++ b/python/ray/data/tests/test_split.py @@ -113,7 +113,7 @@ def _test_equal_split_balanced(block_sizes, num_splits): assert len(split_counts) == num_splits expected_block_size = total_rows // num_splits # Check that all splits are the expected size. - assert all([count == expected_block_size for count in split_counts]) + assert all(count == expected_block_size for count in split_counts) expected_total_rows = sum(split_counts) # Check that the expected number of rows were dropped. assert total_rows - expected_total_rows == total_rows % num_splits @@ -594,9 +594,9 @@ def verify_splits(splits, blocks_by_split): for blocks, (block_refs, meta) in zip(blocks_by_split, splits): assert len(blocks) == len(block_refs) assert len(blocks) == len(meta) - for block, block_ref, meta in zip(blocks, block_refs, meta): + for block, block_ref, _meta in zip(blocks, block_refs, meta): assert list(ray.get(block_ref)["id"]) == block - assert meta.num_rows == len(block) + assert _meta.num_rows == len(block) def test_generate_global_split_results(ray_start_regular_shared): diff --git a/python/ray/data/tests/test_stats.py b/python/ray/data/tests/test_stats.py index 28caa6f6773d..cbb023a3c66e 100644 --- a/python/ray/data/tests/test_stats.py +++ b/python/ray/data/tests/test_stats.py @@ -354,7 +354,7 @@ def test_large_args_scheduling_strategy( f"Dataset throughput:\n" f" * Ray Data throughput: N rows/s\n" f" * Estimated single node throughput: N rows/s\n" - f"{gen_runtime_metrics_str(['ReadRange','MapBatches(dummy_map_batches)'], verbose_stats_logs)}" # noqa: E501 + f"{gen_runtime_metrics_str(['ReadRange', 'MapBatches(dummy_map_batches)'], verbose_stats_logs)}" # noqa: E501 ) assert canonicalize(stats) == expected_stats @@ -421,7 +421,7 @@ def test_dataset_stats_basic( f"Dataset throughput:\n" f" * Ray Data throughput: N rows/s\n" f" * Estimated single node throughput: N rows/s\n" - f"{gen_runtime_metrics_str(['ReadRange->MapBatches(dummy_map_batches)','Map(dummy_map_batches)'], verbose_stats_logs)}" # noqa: E501 + f"{gen_runtime_metrics_str(['ReadRange->MapBatches(dummy_map_batches)', 'Map(dummy_map_batches)'], verbose_stats_logs)}" # noqa: E501 ) for batch in ds.iter_batches(): @@ -473,7 +473,7 @@ def test_dataset_stats_basic( f"Dataset throughput:\n" f" * Ray Data throughput: N rows/s\n" f" * Estimated single node throughput: N rows/s\n" - f"{gen_runtime_metrics_str(['ReadRange->MapBatches(dummy_map_batches)','Map(dummy_map_batches)'], verbose_stats_logs)}" # noqa: E501 + f"{gen_runtime_metrics_str(['ReadRange->MapBatches(dummy_map_batches)', 'Map(dummy_map_batches)'], verbose_stats_logs)}" # noqa: E501 ) diff --git a/python/ray/experimental/channel/nccl_group.py b/python/ray/experimental/channel/nccl_group.py index dcdfef10f163..b34fefbdf86a 100644 --- a/python/ray/experimental/channel/nccl_group.py +++ b/python/ray/experimental/channel/nccl_group.py @@ -10,7 +10,7 @@ ) if TYPE_CHECKING: - import cupy as cp + import cupy as cp # noqa: F401 import torch @@ -92,8 +92,7 @@ def __init__( self._cuda_stream: Optional["cp.cuda.ExternalStream"] = None if cuda_stream is not None: assert rank is not None, "NCCL actor has no rank assigned" - - import cupy as cp + import cupy as cp # noqa: F811 from ray.air._internal import torch_utils diff --git a/python/ray/experimental/raysort/tracing_utils.py b/python/ray/experimental/raysort/tracing_utils.py index 46a2bb315d6b..7a098144e510 100644 --- a/python/ray/experimental/raysort/tracing_utils.py +++ b/python/ray/experimental/raysort/tracing_utils.py @@ -76,7 +76,7 @@ def __init__( gauges: List[str], histograms: List[Tuple[str, List[int]]], ): - self.counts = {m: 0 for m in gauges} + self.counts = dict.fromkeys(gauges, 0) self.gauges = {m: Gauge(m) for m in gauges} self.reset_gauges() self.histograms = {m: Histogram(m, boundaries=b) for m, b in histograms} diff --git a/python/ray/serve/_private/deployment_scheduler.py b/python/ray/serve/_private/deployment_scheduler.py index 221e5aa9b1c6..7646cd28251f 100644 --- a/python/ray/serve/_private/deployment_scheduler.py +++ b/python/ray/serve/_private/deployment_scheduler.py @@ -52,7 +52,7 @@ def can_fit(self, other): def __eq__(self, other): keys = set(self.keys()) | set(other.keys()) - return all([self.get(k) == other.get(k) for k in keys]) + return all(self.get(k) == other.get(k) for k in keys) def __add__(self, other): keys = set(self.keys()) | set(other.keys()) diff --git a/python/ray/serve/config.py b/python/ray/serve/config.py index e7189a1ceb1b..e228da50e3f7 100644 --- a/python/ray/serve/config.py +++ b/python/ray/serve/config.py @@ -272,7 +272,8 @@ def warn_for_middlewares(cls, v, values): "Passing `middlewares` to HTTPOptions is deprecated and will be " "removed in a future version. Consider using the FastAPI integration " "to configure middlewares on your deployments: " - "https://docs.ray.io/en/latest/serve/http-guide.html#fastapi-http-deployments" # noqa 501 + "https://docs.ray.io/en/latest/serve/http-guide.html#fastapi-http-deployments", # noqa 501 + stacklevel=2, ) return v @@ -281,7 +282,8 @@ def warn_for_num_cpus(cls, v, values): if v: warnings.warn( "Passing `num_cpus` to HTTPOptions is deprecated and will be " - "removed in a future version." + "removed in a future version.", + stacklevel=2, ) return v diff --git a/python/ray/serve/context.py b/python/ray/serve/context.py index 32b56f8ffce1..f13df47958b2 100644 --- a/python/ray/serve/context.py +++ b/python/ray/serve/context.py @@ -179,7 +179,8 @@ class _RequestContext: _serve_request_context = contextvars.ContextVar( - "Serve internal request context variable", default=_RequestContext() + "Serve internal request context variable", + default=_RequestContext(), # noqa: B039 ) diff --git a/python/ray/serve/handle.py b/python/ray/serve/handle.py index 944adbd99a3b..2bc722e5096e 100644 --- a/python/ray/serve/handle.py +++ b/python/ray/serve/handle.py @@ -768,7 +768,8 @@ def options( if use_new_handle_api is not DEFAULT.VALUE: warnings.warn( "Setting `use_new_handle_api` no longer has any effect. " - "This argument will be removed in a future version." + "This argument will be removed in a future version.", + stacklevel=2, ) return self._options( diff --git a/python/ray/serve/multiplex.py b/python/ray/serve/multiplex.py index d0cdd942d245..a01adc226d87 100644 --- a/python/ray/serve/multiplex.py +++ b/python/ray/serve/multiplex.py @@ -173,7 +173,7 @@ async def load_model(self, model_id: str) -> Any: The user-constructed model object. """ - if type(model_id) != str: + if not isinstance(model_id, str): raise TypeError("The model ID must be a string.") if not model_id: diff --git a/python/ray/serve/tests/test_deploy.py b/python/ray/serve/tests/test_deploy.py index 7cde6f9f73a9..3559852cc853 100644 --- a/python/ray/serve/tests/test_deploy.py +++ b/python/ray/serve/tests/test_deploy.py @@ -350,7 +350,7 @@ def reconfigure(): signal.send.remote() ray.get(reconfigure_ref) - assert all([r.result() == 1 for r in responses]) + assert all(r.result() == 1 for r in responses) assert handle.remote().result() == 2 diff --git a/python/ray/serve/tests/test_grpc.py b/python/ray/serve/tests/test_grpc.py index 7ce7a7e986e2..d36b7512e74a 100644 --- a/python/ray/serve/tests/test_grpc.py +++ b/python/ray/serve/tests/test_grpc.py @@ -591,7 +591,7 @@ def Streaming( assert error_message == rpc_error.details() assert trailing_metadata in rpc_error.trailing_metadata() # request_id should always be set in the trailing metadata. - assert any([key == "request_id" for key, _ in rpc_error.trailing_metadata()]) + assert any(key == "request_id" for key, _ in rpc_error.trailing_metadata()) @pytest.mark.parametrize("streaming", [False, True]) diff --git a/python/ray/serve/tests/test_proxy_state.py b/python/ray/serve/tests/test_proxy_state.py index c84dea865872..cf4f9635b164 100644 --- a/python/ray/serve/tests/test_proxy_state.py +++ b/python/ray/serve/tests/test_proxy_state.py @@ -148,10 +148,8 @@ def _update_and_check_proxy_state_manager( proxy_state_manager.update(**kwargs) proxy_states = proxy_state_manager._proxy_states assert all( - [ - proxy_states[node_ids[idx]].status == statuses[idx] - for idx in range(len(node_ids)) - ] + proxy_states[node_ids[idx]].status == statuses[idx] + for idx in range(len(node_ids)) ), [proxy_state.status for proxy_state in proxy_states.values()] return True @@ -622,9 +620,7 @@ def test_proxy_state_manager_timing_out_on_start(number_of_worker_nodes, all_nod proxy_state._actor_proxy_wrapper.is_ready_response = False # Capture current proxy states (prior to updating) - prev_proxy_states = { - node_id: state for node_id, state in proxy_state_manager._proxy_states.items() - } + prev_proxy_states = dict(proxy_state_manager._proxy_states) # Trigger PSM to reconcile proxy_state_manager.update(proxy_nodes=node_ids) @@ -645,9 +641,7 @@ def test_proxy_state_manager_timing_out_on_start(number_of_worker_nodes, all_nod proxy_state._actor_proxy_wrapper.is_ready_response = True # Capture current proxy states again (prior to updating) - prev_proxy_states = { - node_id: state for node_id, state in proxy_state_manager._proxy_states.items() - } + prev_proxy_states = dict(proxy_state_manager._proxy_states) # Trigger PSM to reconcile proxy_state_manager.update(proxy_nodes=node_ids) diff --git a/python/ray/serve/tests/test_standalone.py b/python/ray/serve/tests/test_standalone.py index 0936432dc3aa..98517b04f9f6 100644 --- a/python/ray/serve/tests/test_standalone.py +++ b/python/ray/serve/tests/test_standalone.py @@ -2,6 +2,7 @@ The test file for all standalone tests that doesn't requires a shared Serve instance. """ + import logging import os import socket @@ -536,7 +537,7 @@ def test_http_head_only(ray_cluster): cpu_per_nodes = { r["CPU"] for r in ray._private.state.available_resources_per_node().values() } - assert cpu_per_nodes == {4, 4} + assert cpu_per_nodes == {4} def test_serve_shutdown(ray_shutdown): diff --git a/python/ray/serve/tests/test_standalone_3.py b/python/ray/serve/tests/test_standalone_3.py index 2b11becf41f3..09f5e049f953 100644 --- a/python/ray/serve/tests/test_standalone_3.py +++ b/python/ray/serve/tests/test_standalone_3.py @@ -461,7 +461,7 @@ def __call__(self): # Ensure the all resources are shutdown. wait_for_condition( lambda: all( - [actor["State"] == "DEAD" for actor in ray._private.state.actors().values()] + actor["State"] == "DEAD" for actor in ray._private.state.actors().values() ) ) @@ -521,7 +521,7 @@ def __call__(self): # Ensure the all resources are shutdown gracefully. wait_for_condition( lambda: all( - [actor["State"] == "DEAD" for actor in ray._private.state.actors().values()] + actor["State"] == "DEAD" for actor in ray._private.state.actors().values() ), ) @@ -554,7 +554,7 @@ def __call__(self): # Ensure the all resources are shutdown gracefully. wait_for_condition( lambda: all( - [actor["State"] == "DEAD" for actor in ray._private.state.actors().values()] + actor["State"] == "DEAD" for actor in ray._private.state.actors().values() ), ) diff --git a/python/ray/serve/tests/unit/test_deployment_scheduler.py b/python/ray/serve/tests/unit/test_deployment_scheduler.py index 003b7e004e3e..6631a743793b 100644 --- a/python/ray/serve/tests/unit/test_deployment_scheduler.py +++ b/python/ray/serve/tests/unit/test_deployment_scheduler.py @@ -304,7 +304,7 @@ def test_get_node_to_running_replicas(): # Test random case node_to_running_replicas = defaultdict(set) for i in range(40): - node_id = f"node{random.randint(0,5)}" + node_id = f"node{random.randint(0, 5)}" r_id = ReplicaID(f"r{i}", d_id) node_to_running_replicas[node_id].add(r_id) scheduler.on_replica_running(r_id, node_id) diff --git a/python/ray/serve/tests/unit/test_router.py b/python/ray/serve/tests/unit/test_router.py index 3870ac73f046..27bec9708fdf 100644 --- a/python/ray/serve/tests/unit/test_router.py +++ b/python/ray/serve/tests/unit/test_router.py @@ -402,11 +402,9 @@ async def test_max_queued_requests_no_limit( # Unblock the requests, now they should all get scheduled. fake_replica_scheduler.unblock_requests(100) assert all( - [ - not replica_result._is_generator_object - and replica_result._replica_id == r1_id - for replica_result in await asyncio.gather(*assign_request_tasks) - ] + not replica_result._is_generator_object + and replica_result._replica_id == r1_id + for replica_result in await asyncio.gather(*assign_request_tasks) ) async def test_max_queued_requests_limited( @@ -461,11 +459,9 @@ async def test_max_queued_requests_limited( # Unblock the requests, now they should all get scheduled. fake_replica_scheduler.unblock_requests(5) assert all( - [ - not replica_result._is_generator_object - and replica_result._replica_id == r1_id - for replica_result in await asyncio.gather(*assign_request_tasks) - ] + not replica_result._is_generator_object + and replica_result._replica_id == r1_id + for replica_result in await asyncio.gather(*assign_request_tasks) ) async def test_max_queued_requests_updated( @@ -533,11 +529,9 @@ async def test_max_queued_requests_updated( done, pending = await asyncio.wait(assign_request_tasks, timeout=0.01) assert len(pending) == 5 assert all( - [ - not replica_result._is_generator_object - and replica_result._replica_id == r1_id - for replica_result in await asyncio.gather(*done) - ] + not replica_result._is_generator_object + and replica_result._replica_id == r1_id + for replica_result in await asyncio.gather(*done) ) assign_request_tasks = list(pending) @@ -550,11 +544,9 @@ async def test_max_queued_requests_updated( # Unblock the requests, now they should all get scheduled. fake_replica_scheduler.unblock_requests(5) assert all( - [ - not replica_result._is_generator_object - and replica_result._replica_id == r1_id - for replica_result in await asyncio.gather(*assign_request_tasks) - ] + not replica_result._is_generator_object + and replica_result._replica_id == r1_id + for replica_result in await asyncio.gather(*assign_request_tasks) ) @pytest.mark.parametrize( diff --git a/python/ray/tests/autoscaler_test_utils.py b/python/ray/tests/autoscaler_test_utils.py index 178dc1d2cf4c..8cbcebd6ac2a 100644 --- a/python/ray/tests/autoscaler_test_utils.py +++ b/python/ray/tests/autoscaler_test_utils.py @@ -2,7 +2,7 @@ import threading from subprocess import CalledProcessError -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional from ray.autoscaler.node_provider import NodeProvider diff --git a/python/ray/tests/aws/utils/stubs.py b/python/ray/tests/aws/utils/stubs.py index 34012adcc7f1..3b44eb5704fb 100644 --- a/python/ray/tests/aws/utils/stubs.py +++ b/python/ray/tests/aws/utils/stubs.py @@ -280,8 +280,8 @@ def describe_instance_status_ok(ec2_client_stub, instance_ids): "Details": [{"Status": "passed", "Name": "reachability"}], }, } + for instance_id in instance_ids ] - for instance_id in instance_ids }, ) diff --git a/python/ray/tests/gcp/test_gcp_tpu_command_runner.py b/python/ray/tests/gcp/test_gcp_tpu_command_runner.py index 7febb6cdde72..19a710b3f9bf 100644 --- a/python/ray/tests/gcp/test_gcp_tpu_command_runner.py +++ b/python/ray/tests/gcp/test_gcp_tpu_command_runner.py @@ -242,7 +242,7 @@ def test_max_active_connections_env_var(): cmd_runner = TPUCommandRunner(**args) os.environ[ray_constants.RAY_TPU_MAX_CONCURRENT_CONNECTIONS_ENV_VAR] = "1" num_connections = cmd_runner.num_connections - assert type(num_connections) == int + assert isinstance(num_connections, int) assert num_connections == 1 diff --git a/python/ray/tests/modin/modin_test_utils.py b/python/ray/tests/modin/modin_test_utils.py index 96fcde04d2a9..34534619cda7 100644 --- a/python/ray/tests/modin/modin_test_utils.py +++ b/python/ray/tests/modin/modin_test_utils.py @@ -105,7 +105,7 @@ def df_equals(df1, df2): if isinstance(df1, pandas.DataFrame) and isinstance(df2, pandas.DataFrame): if (df1.empty and not df2.empty) or (df2.empty and not df1.empty): assert False, "One of the passed frames is empty, when other isn't" - elif df1.empty and df2.empty and type(df1) != type(df2): + elif df1.empty and df2.empty and not isinstance(df1, df2): assert ( False ), f"Empty frames have different types: {type(df1)} != {type(df2)}" diff --git a/python/ray/tests/test_basic_4.py b/python/ray/tests/test_basic_4.py index 269062e78951..e1d02268c88f 100644 --- a/python/ray/tests/test_basic_4.py +++ b/python/ray/tests/test_basic_4.py @@ -211,7 +211,7 @@ def verify(): def run(): try: verify() - except BaseException as e: + except BaseException as e: # noqa: B036 To avoid breaking change exc.append(e) import threading diff --git a/python/ray/tests/test_client_builder.py b/python/ray/tests/test_client_builder.py index a426e9502452..73fcecced195 100644 --- a/python/ray/tests/test_client_builder.py +++ b/python/ray/tests/test_client_builder.py @@ -51,7 +51,7 @@ def test_client(address): if address in ("local", None): assert isinstance(builder, client_builder._LocalClientBuilder) else: - assert type(builder) == client_builder.ClientBuilder + assert isinstance(builder, client_builder.ClientBuilder) assert builder.address == address.replace("ray://", "") diff --git a/python/ray/tests/test_failure_3.py b/python/ray/tests/test_failure_3.py index 1e2b3a56da1a..ba89c8399fb8 100644 --- a/python/ray/tests/test_failure_3.py +++ b/python/ray/tests/test_failure_3.py @@ -441,12 +441,12 @@ def task(): # Validate all children of the worker processes are in a sleeping state. processes = [psutil.Process(pid) for pid in pids] - assert all([proc.status() == psutil.STATUS_SLEEPING for proc in processes]) + assert all(proc.status() == psutil.STATUS_SLEEPING for proc in processes) # Valdiate children of worker process die after SIGINT. driver_proc.send_signal(signal.SIGINT) wait_for_condition( - condition_predictor=lambda: all([not proc.is_running() for proc in processes]), + condition_predictor=lambda: all(not proc.is_running() for proc in processes), timeout=30, ) @@ -543,7 +543,7 @@ def leaker_task(index): # Validate all children of the worker processes are in a sleeping state. processes = [psutil.Process(pid) for pid in pids] - assert all([proc.status() == psutil.STATUS_SLEEPING for proc in processes]) + assert all(proc.status() == psutil.STATUS_SLEEPING for proc in processes) # Obtain psutil handle for raylet process raylet_proc = [p for p in psutil.process_iter() if p.name() == "raylet"] @@ -556,7 +556,7 @@ def leaker_task(index): print("Waiting for child procs to die") wait_for_condition( - condition_predictor=lambda: all([not proc.is_running() for proc in processes]), + condition_predictor=lambda: all(not proc.is_running() for proc in processes), timeout=30, ) diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py index bca9b83021de..544a71143fe8 100644 --- a/python/ray/tests/test_gcs_fault_tolerance.py +++ b/python/ray/tests/test_gcs_fault_tolerance.py @@ -815,7 +815,7 @@ def pid(self): # Wait until all data is updated in the replica leader_cli.set("_hole", "0") - wait_for_condition(lambda: all([b"_hole" in f.keys("*") for f in follower_cli])) + wait_for_condition(lambda: all(b"_hole" in f.keys("*") for f in follower_cli)) # Now kill pid leader_process = psutil.Process(pid=leader_pid) diff --git a/python/ray/tests/test_metrics.py b/python/ray/tests/test_metrics.py index d770af8c5347..551318f50ff5 100644 --- a/python/ray/tests/test_metrics.py +++ b/python/ray/tests/test_metrics.py @@ -107,8 +107,8 @@ def get_owner_info(node_ids): } # Force a global gc to clean up the object store. ray._private.internal_api.global_gc() - owner_stats = {n: 0 for n in node_ids} - primary_copy_stats = {n: 0 for n in node_ids} + owner_stats = dict.fromkeys(node_ids, 0) + primary_copy_stats = dict.fromkeys(node_ids, 0) for node_id in node_ids: node_stats = ray._private.internal_api.node_stats( diff --git a/python/ray/tests/test_metrics_agent_2.py b/python/ray/tests/test_metrics_agent_2.py index 6d1550d15b9e..a5abe94f47a2 100644 --- a/python/ray/tests/test_metrics_agent_2.py +++ b/python/ray/tests/test_metrics_agent_2.py @@ -350,8 +350,8 @@ def test_metrics_agent_proxy_record_and_export_from_workers_complicated( # Make sure the rest of metrics are still there because new metrics # are reported. - for i in range(i + 2, len(metrics)): - assert get_metric(f"{namespace}_test_{i}", agent_port) is not None, i + for j in range(i + 2, len(metrics)): + assert get_metric(f"{namespace}_test_{j}", agent_port) is not None, j i += 2 diff --git a/python/ray/tests/test_network_failure_e2e.py b/python/ray/tests/test_network_failure_e2e.py index 02e493acbc89..a8504648f217 100644 --- a/python/ray/tests/test_network_failure_e2e.py +++ b/python/ray/tests/test_network_failure_e2e.py @@ -64,7 +64,7 @@ def check_task_running(n=None): print("tasks_json:", json.dumps(tasks_json, indent=2)) if n is not None and n != len(tasks_json): return False - return all([task["state"] == "RUNNING" for task in tasks_json]) + return all(task["state"] == "RUNNING" for task in tasks_json) return False # list_task make sure all tasks are running @@ -102,7 +102,7 @@ def check_task_not_running(): if output.exit_code == 0: tasks_json = json.loads(output.output) print("tasks_json:", json.dumps(tasks_json, indent=2)) - return all([task["state"] != "RUNNING" for task in tasks_json]) + return all(task["state"] != "RUNNING" for task in tasks_json) return False # we set num_cpus=0 for head node. diff --git a/python/ray/tests/test_object_manager.py b/python/ray/tests/test_object_manager.py index 45d1f05964c4..5a376b9e33d8 100644 --- a/python/ray/tests/test_object_manager.py +++ b/python/ray/tests/test_object_manager.py @@ -14,7 +14,10 @@ multiprocessing.cpu_count() < 40 or ray._private.utils.get_system_memory() < 50 * 10**9 ): - warnings.warn("This test must be run on large machines.") + warnings.warn( + "This test must be run on large machines.", + stacklevel=2, + ) def create_cluster(num_nodes): @@ -134,7 +137,8 @@ def create_object(): if len(relevant_events) > num_nodes - 1: warnings.warn( "This object was transferred {} times, when only {} " - "transfers were required.".format(len(relevant_events), num_nodes - 1) + "transfers were required.".format(len(relevant_events), num_nodes - 1), + stacklevel=2, ) # Each object should not have been broadcast more than once from every # machine to every other machine. Also, a pair of machines should not diff --git a/python/ray/tests/test_runtime_env_packaging.py b/python/ray/tests/test_runtime_env_packaging.py index c2a82318f5a4..2313f30ede45 100644 --- a/python/ray/tests/test_runtime_env_packaging.py +++ b/python/ray/tests/test_runtime_env_packaging.py @@ -449,7 +449,9 @@ def test_unzip_package_with_multiple_top_level_dirs( unlink_zip, random_zip_file_without_top_level_dir, ): - """Test unzipping a package with multiple top level directories (not counting __MACOSX). + """ + Test unzipping a package with multiple top level directories + (not counting __MACOSX). Tests that we don't remove the top level directory, regardless of the value of remove_top_level_directory. diff --git a/python/ray/tests/test_serialization.py b/python/ray/tests/test_serialization.py index b19979318817..ae9fdf6b1a8c 100644 --- a/python/ray/tests/test_serialization.py +++ b/python/ray/tests/test_serialization.py @@ -28,7 +28,7 @@ def is_named_tuple(cls): f = getattr(cls, "_fields", None) if not isinstance(f, tuple): return False - return all(type(n) == str for n in f) + return all(isinstance(n, str) for n in f) @pytest.mark.parametrize( @@ -95,8 +95,8 @@ def f(x): # TODO(rkn): The numpy dtypes currently come back as regular integers # or floats. if type(obj).__module__ != "numpy": - assert type(obj) == type(new_obj_1) - assert type(obj) == type(new_obj_2) + assert isinstance(obj, new_obj_1) + assert isinstance(obj, new_obj_2) @pytest.mark.parametrize( diff --git a/python/ray/tests/test_state_api_log.py b/python/ray/tests/test_state_api_log.py index 262c0c90b723..3c28fc119fbd 100644 --- a/python/ray/tests/test_state_api_log.py +++ b/python/ray/tests/test_state_api_log.py @@ -862,7 +862,7 @@ def verify_worker_logs(): "worker_out", "worker_err", ] - assert all([cat in logs for cat in worker_log_categories]) + assert all(cat in logs for cat in worker_log_categories) num_workers = len( list(filter(lambda w: w["worker_type"] == "WORKER", list_workers())) ) diff --git a/python/ray/tests/test_usage_stats.py b/python/ray/tests/test_usage_stats.py index 3cb680eddf0e..fa7570790864 100644 --- a/python/ray/tests/test_usage_stats.py +++ b/python/ray/tests/test_usage_stats.py @@ -1317,7 +1317,7 @@ def test_usage_report_disabled_ray_init_cluster( contents = f.readlines() break assert contents is not None - assert any(["Usage reporting is disabled" in c for c in contents]) + assert any("Usage reporting is disabled" in c for c in contents) def test_usage_report_disabled( @@ -1369,8 +1369,8 @@ def test_usage_report_disabled( contents = f.readlines() break assert contents is not None - assert any(["Usage reporting is disabled" in c for c in contents]) - assert all(["Usage report request failed" not in c for c in contents]) + assert any("Usage reporting is disabled" in c for c in contents) + assert all("Usage report request failed" not in c for c in contents) def test_usage_file_error_message(monkeypatch, ray_start_cluster, reset_usage_stats): diff --git a/python/ray/train/_internal/session.py b/python/ray/train/_internal/session.py index 7918675a22ea..1a7883fae80e 100644 --- a/python/ray/train/_internal/session.py +++ b/python/ray/train/_internal/session.py @@ -484,7 +484,8 @@ def get_dataset_shard( warnings.warn( "No dataset passed in. Returning None. Make sure to " "pass in a Dataset to Trainer.run to use this " - "function." + "function.", + stacklevel=2, ) elif isinstance(shard, dict): if not dataset_name: @@ -652,7 +653,8 @@ def wrapper(*args, **kwargs): warnings.warn( f"`{fn_name}` is meant to only be " "called inside a function that is executed by a Tuner" - f" or Trainer. Returning `{default_value}`." + f" or Trainer. Returning `{default_value}`.", + stacklevel=2, ) return default_value return fn(*args, **kwargs) diff --git a/python/ray/train/_internal/syncer.py b/python/ray/train/_internal/syncer.py index 4413e9245295..6461a66da4c5 100644 --- a/python/ray/train/_internal/syncer.py +++ b/python/ray/train/_internal/syncer.py @@ -275,6 +275,7 @@ def delete(self, remote_dir: str) -> bool: """ raise NotImplementedError + @abc.abstractmethod def retry(self): """Retry the last sync up, sync down, or delete command. @@ -283,6 +284,7 @@ def retry(self): """ pass + @abc.abstractmethod def wait(self, timeout: Optional[float] = None): """Wait for asynchronous sync command to finish. @@ -367,6 +369,7 @@ def reset(self): self.last_sync_up_time = float("-inf") self.last_sync_down_time = float("-inf") + @abc.abstractmethod def close(self): pass diff --git a/python/ray/train/base_trainer.py b/python/ray/train/base_trainer.py index 3cef395cbf63..879c8a65fedf 100644 --- a/python/ray/train/base_trainer.py +++ b/python/ray/train/base_trainer.py @@ -361,7 +361,8 @@ def training_loop(self): f"Invalid trainer type. You are attempting to restore a trainer of type" f" {trainer_cls} with `{cls.__name__}.restore`, " "which will most likely fail. " - f"Use `{trainer_cls.__name__}.restore` instead." + f"Use `{trainer_cls.__name__}.restore` instead.", + stacklevel=2, ) original_datasets = param_dict.pop("datasets", {}) @@ -514,6 +515,7 @@ def _validate_scaling_config(cls, scaling_config: ScalingConfig) -> ScalingConfi ) return scaling_config + @abc.abstractmethod def setup(self) -> None: """Called during fit() to perform initial setup on the Trainer. diff --git a/python/ray/train/examples/horovod/horovod_cifar_pbt_example.py b/python/ray/train/examples/horovod/horovod_cifar_pbt_example.py index ea6ffe36fc31..66f5c321cd59 100755 --- a/python/ray/train/examples/horovod/horovod_cifar_pbt_example.py +++ b/python/ray/train/examples/horovod/horovod_cifar_pbt_example.py @@ -52,7 +52,7 @@ def train_loop_per_worker(config): net.parameters(), lr=config["lr"], ) - epoch = 0 + epochs = 0 checkpoint = train.get_checkpoint() if checkpoint: @@ -62,7 +62,7 @@ def train_loop_per_worker(config): model_state = checkpoint_dict["model_state"] optimizer_state = checkpoint_dict["optimizer_state"] - epoch = checkpoint_dict["epoch"] + 1 + epochs = checkpoint_dict["epoch"] + 1 net.load_state_dict(model_state) optimizer.load_state_dict(optimizer_state) @@ -88,7 +88,7 @@ def train_loop_per_worker(config): trainset, batch_size=int(config["batch_size"]), sampler=train_sampler ) - for epoch in range(epoch, 40): # loop over the dataset multiple times + for epoch in range(epochs, 40): # loop over the dataset multiple times running_loss = 0.0 epoch_steps = 0 for i, data in enumerate(trainloader): diff --git a/python/ray/train/examples/pytorch/torch_data_prefetch_benchmark/auto_pipeline_for_host_to_device_data_transfer.py b/python/ray/train/examples/pytorch/torch_data_prefetch_benchmark/auto_pipeline_for_host_to_device_data_transfer.py index 28fe7461bc3c..f5cc53e95369 100644 --- a/python/ray/train/examples/pytorch/torch_data_prefetch_benchmark/auto_pipeline_for_host_to_device_data_transfer.py +++ b/python/ray/train/examples/pytorch/torch_data_prefetch_benchmark/auto_pipeline_for_host_to_device_data_transfer.py @@ -141,7 +141,10 @@ def train_linear(num_workers=1, num_hidden_layers=1, use_auto_transfer=True, epo ray.init(address=args.address) if not torch.cuda.is_available(): - warnings.warn("GPU is not available. Skip the test using auto pipeline.") + warnings.warn( + "GPU is not available. Skip the test using auto pipeline.", + stacklevel=2, + ) else: train_linear( num_workers=1, diff --git a/python/ray/train/tests/test_torch_trainer.py b/python/ray/train/tests/test_torch_trainer.py index d9a2c86d9d28..b08b932998f0 100644 --- a/python/ray/train/tests/test_torch_trainer.py +++ b/python/ray/train/tests/test_torch_trainer.py @@ -188,7 +188,8 @@ def single_worker_fail(): @pytest.mark.parametrize("num_gpus_per_worker", [0.5, 1, 2]) def test_tune_torch_get_device_gpu(num_gpus_per_worker): - """Tests if GPU ids are set correctly when running train concurrently in nested actors + """ + Tests if GPU ids are set correctly when running train concurrently in nested actors (for example when used with Tune). """ from ray.train import ScalingConfig diff --git a/python/ray/train/torch/torch_checkpoint.py b/python/ray/train/torch/torch_checkpoint.py index 7b6eeae30518..92044ce3fc8a 100644 --- a/python/ray/train/torch/torch_checkpoint.py +++ b/python/ray/train/torch/torch_checkpoint.py @@ -174,7 +174,8 @@ def get_model(self, model: Optional[torch.nn.Module] = None) -> torch.nn.Module: "Discarding provided `model` argument. If you are using " "TorchPredictor directly, you should do " "`TorchPredictor.from_checkpoint(checkpoint)` by removing kwargs " - "`model=`." + "`model=`.", + stacklevel=2, ) model = load_torch_model( saved_model=model_or_state_dict, model_definition=model diff --git a/python/ray/train/torch/xla/config.py b/python/ray/train/torch/xla/config.py index e965f9fc269a..35d41d082bf0 100644 --- a/python/ray/train/torch/xla/config.py +++ b/python/ray/train/torch/xla/config.py @@ -98,7 +98,7 @@ def _neuron_compile_extracted_graphs(): logger.info("Compiling extracted graphs on local rank0 worker") parallel_compile_workdir = ( - f"/tmp/{os.environ.get('USER','no-user')}/parallel_compile_workdir/" + f"/tmp/{os.environ.get('USER', 'no-user')}/parallel_compile_workdir/" ) if os.path.exists(parallel_compile_workdir): shutil.rmtree(parallel_compile_workdir) diff --git a/python/ray/tune/callback.py b/python/ray/tune/callback.py index 1295a0e61cef..9e3505a60f7d 100644 --- a/python/ray/tune/callback.py +++ b/python/ray/tune/callback.py @@ -380,6 +380,7 @@ def setup(self, **info): f"`{callback.__class__}` to match the method signature" " in `ray.tune.callback.Callback`.", FutureWarning, + stacklevel=2, ) callback.setup() else: diff --git a/python/ray/tune/examples/custom_func_checkpointing.py b/python/ray/tune/examples/custom_func_checkpointing.py index 333b71346bcb..c5fa5a5f1f59 100644 --- a/python/ray/tune/examples/custom_func_checkpointing.py +++ b/python/ray/tune/examples/custom_func_checkpointing.py @@ -27,14 +27,14 @@ def train_func(config): state = json.load(f) step = state["step"] + 1 - for step in range(step, 100): - intermediate_score = evaluation_fn(step, width, height) + for step_idx in range(step, 100): + intermediate_score = evaluation_fn(step_idx, width, height) with tempfile.TemporaryDirectory() as temp_checkpoint_dir: with open(os.path.join(temp_checkpoint_dir, "checkpoint.json"), "w") as f: - json.dump({"step": step}, f) + json.dump({"step": step_idx}, f) train.report( - {"iterations": step, "mean_loss": intermediate_score}, + {"iterations": step_idx, "mean_loss": intermediate_score}, checkpoint=Checkpoint.from_directory(temp_checkpoint_dir), ) diff --git a/python/ray/tune/execution/placement_groups.py b/python/ray/tune/execution/placement_groups.py index 0848b147878d..a2c818cdb1ff 100644 --- a/python/ray/tune/execution/placement_groups.py +++ b/python/ray/tune/execution/placement_groups.py @@ -96,6 +96,7 @@ def __call__(self, *args, **kwargs): "Calling PlacementGroupFactory objects is deprecated. Use " "`to_placement_group()` instead.", DeprecationWarning, + stacklevel=2, ) kwargs.update(self._bound.kwargs) # Call with bounded *args and **kwargs @@ -114,7 +115,7 @@ def resource_dict_to_pg_factory(spec: Optional[Dict[str, float]] = None): memory = spec.pop("memory", 0.0) # If there is a custom_resources key, use as base for bundle - bundle = {k: v for k, v in spec.pop("custom_resources", {}).items()} + bundle = dict(spec.pop("custom_resources", {})) # Otherwise, consider all other keys as custom resources if not bundle: diff --git a/python/ray/tune/execution/tune_controller.py b/python/ray/tune/execution/tune_controller.py index bb482a80e6e0..c9dadf603fd5 100644 --- a/python/ray/tune/execution/tune_controller.py +++ b/python/ray/tune/execution/tune_controller.py @@ -188,7 +188,8 @@ def __init__( "mode as resources (such as Ray processes, " "file descriptors, and temporary files) may not be " "cleaned up properly. To use " - "a safer mode, use fail_fast=True." + "a safer mode, use fail_fast=True.", + stacklevel=2, ) else: raise ValueError( @@ -1202,9 +1203,6 @@ def _schedule_trial_task( tracked_actor = self._trial_to_actor[trial] - _on_result = None - _on_error = None - args = args or tuple() kwargs = kwargs or {} @@ -1228,6 +1226,8 @@ def _on_result(tracked_actor: TrackedActor, *args, **kwargs): else: raise TuneError(traceback.format_exc()) + else: + _on_result = None if on_error: def _on_error(tracked_actor: TrackedActor, exception: Exception): @@ -1253,6 +1253,8 @@ def _on_error(tracked_actor: TrackedActor, exception: Exception): else: raise TuneError(traceback.format_exc()) + else: + _on_error = None logger.debug(f"Future {method_name.upper()} SCHEDULED for trial {trial}") with _change_working_directory(trial): diff --git a/python/ray/tune/integration/pytorch_lightning.py b/python/ray/tune/integration/pytorch_lightning.py index 0ca554ebced8..4bf9d7a6cb3c 100644 --- a/python/ray/tune/integration/pytorch_lightning.py +++ b/python/ray/tune/integration/pytorch_lightning.py @@ -199,7 +199,8 @@ def __init__( "`ray.tune.integration.pytorch_lightning.TuneReportCallback` " "is deprecated. Use " "`ray.tune.integration.pytorch_lightning.TuneReportCheckpointCallback`" - " instead." + " instead.", + stacklevel=2, ) super(TuneReportCallback, self).__init__( metrics=metrics, save_checkpoints=False, on=on diff --git a/python/ray/tune/logger/logger.py b/python/ray/tune/logger/logger.py index ad14069c3c20..c8fe5dea31e3 100644 --- a/python/ray/tune/logger/logger.py +++ b/python/ray/tune/logger/logger.py @@ -51,6 +51,7 @@ def __init__(self, config: Dict, logdir: str, trial: Optional["Trial"] = None): self.trial = trial self._init() + @abc.abstractmethod def _init(self): pass @@ -59,16 +60,19 @@ def on_result(self, result): raise NotImplementedError + @abc.abstractmethod def update_config(self, config): """Updates the config for logger.""" pass + @abc.abstractmethod def close(self): """Releases all resources used by this logger.""" pass + @abc.abstractmethod def flush(self): """Flushes all disk writes to storage.""" diff --git a/python/ray/tune/progress_reporter.py b/python/ray/tune/progress_reporter.py index 2d83fb814408..e3e53ade5fd9 100644 --- a/python/ray/tune/progress_reporter.py +++ b/python/ray/tune/progress_reporter.py @@ -335,7 +335,8 @@ def _progress_str( self._sort_by_metric = False warnings.warn( "Both 'metric' and 'mode' must be set to be able " - "to sort by metric. No sorting is performed." + "to sort by metric. No sorting is performed.", + stacklevel=2, ) if not self._metrics_override: user_metrics = self._infer_user_metrics(trials, self._infer_limit) @@ -530,7 +531,8 @@ def __init__( "If this leads to unformatted output (e.g. like " "), consider passing " "a `CLIReporter` as the `progress_reporter` argument " - "to `train.RunConfig()` instead." + "to `train.RunConfig()` instead.", + stacklevel=2, ) self._overwrite = overwrite diff --git a/python/ray/tune/requirements-dev.txt b/python/ray/tune/requirements-dev.txt index e4432a5471c6..86e44349e881 100644 --- a/python/ray/tune/requirements-dev.txt +++ b/python/ray/tune/requirements-dev.txt @@ -1,4 +1,4 @@ -flake8==3.9.1 +flake8==7.1.1 flake8-quotes gym>=0.21.0,<0.24.0 scikit-image diff --git a/python/ray/tune/schedulers/pbt.py b/python/ray/tune/schedulers/pbt.py index 0c389f76dcd0..60ad3689a25a 100644 --- a/python/ray/tune/schedulers/pbt.py +++ b/python/ray/tune/schedulers/pbt.py @@ -481,7 +481,8 @@ def on_trial_add(self, tune_controller: "TuneController", trial: Trial): "Using `CheckpointConfig.num_to_keep <= 2` with PBT can lead to " "restoration problems when checkpoint are deleted too early for " "other trials to exploit them. If this happens, increase the value " - "of `num_to_keep`." + "of `num_to_keep`.", + stacklevel=2, ) self._trial_state[trial] = _PBTTrialState(trial) diff --git a/python/ray/tune/schedulers/resource_changing_scheduler.py b/python/ray/tune/schedulers/resource_changing_scheduler.py index 24d437cf892f..5b7f0ac7dfcc 100644 --- a/python/ray/tune/schedulers/resource_changing_scheduler.py +++ b/python/ray/tune/schedulers/resource_changing_scheduler.py @@ -683,7 +683,8 @@ def __init__( warnings.warn( "`resources_allocation_function` is None. No resource " "requirements will be changed at any time. Pass a " - "correctly defined function to enable functionality." + "correctly defined function to enable functionality.", + stacklevel=2, ) self._resources_allocation_function = resources_allocation_function self._base_scheduler = base_scheduler or FIFOScheduler() diff --git a/python/ray/tune/search/basic_variant.py b/python/ray/tune/search/basic_variant.py index c9a59bd95ffd..149edbb99556 100644 --- a/python/ray/tune/search/basic_variant.py +++ b/python/ray/tune/search/basic_variant.py @@ -339,7 +339,8 @@ def add_configurations( "exceeds the serialization threshold " f"({int(SERIALIZATION_THRESHOLD)}). Resume ability is " "disabled. To fix this, reduce the number of " - "dimensions/size of the provided grid search." + "dimensions/size of the provided grid search.", + stacklevel=2, ) previous_samples = self._total_samples diff --git a/python/ray/tune/search/optuna/optuna_search.py b/python/ray/tune/search/optuna/optuna_search.py index f1656039a4af..ba79f4715912 100644 --- a/python/ray/tune/search/optuna/optuna_search.py +++ b/python/ray/tune/search/optuna/optuna_search.py @@ -456,7 +456,8 @@ def _suggest_from_define_by_run_func( f"took {time_taken} seconds to " "run. Ensure that actual computation, training takes " "place inside Tune's train functions or Trainables " - "passed to `tune.Tuner()`." + "passed to `tune.Tuner()`.", + stacklevel=2, ) if ret is not None: if not isinstance(ret, dict): @@ -589,9 +590,7 @@ def add_evaluated_point( ot_trial_state = OptunaTrialState.PRUNED if intermediate_values: - intermediate_values_dict = { - i: value for i, value in enumerate(intermediate_values) - } + intermediate_values_dict = dict(enumerate(intermediate_values)) else: intermediate_values_dict = None diff --git a/python/ray/tune/search/searcher.py b/python/ray/tune/search/searcher.py index 55f32af56e05..6762c6950d6c 100644 --- a/python/ray/tune/search/searcher.py +++ b/python/ray/tune/search/searcher.py @@ -276,7 +276,8 @@ def trial_to_points(trial: Trial) -> Dict[str, Any]: if not any_trial_had_metric: warnings.warn( "No completed trial returned the specified metric. " - "Make sure the name you have passed is correct. " + "Make sure the name you have passed is correct. ", + stacklevel=2, ) def save(self, checkpoint_path: str): diff --git a/python/ray/tune/stopper/stopper.py b/python/ray/tune/stopper/stopper.py index 1c2ff60fd1f1..fa4197cd6450 100644 --- a/python/ray/tune/stopper/stopper.py +++ b/python/ray/tune/stopper/stopper.py @@ -44,10 +44,12 @@ class Stopper(abc.ABC): """ + @abc.abstractmethod def __call__(self, trial_id: str, result: Dict[str, Any]) -> bool: """Returns true if the trial should be terminated given the result.""" raise NotImplementedError + @abc.abstractmethod def stop_all(self) -> bool: """Returns true if the experiment should be terminated.""" raise NotImplementedError diff --git a/python/ray/tune/tests/execution/utils.py b/python/ray/tune/tests/execution/utils.py index c9e5cd54ff65..f410b704cab0 100644 --- a/python/ray/tune/tests/execution/utils.py +++ b/python/ray/tune/tests/execution/utils.py @@ -87,7 +87,7 @@ def next(self, timeout: Optional[Union[int, float]] = None) -> None: pass def set_num_pending(self, num_pending: int): - self._pending_actors_to_attrs = {i: None for i in range(num_pending)} + self._pending_actors_to_attrs = dict.fromkeys(range(num_pending), None) class _FakeResourceUpdater(_ResourceUpdater): diff --git a/python/ray/tune/tests/test_progress_reporter.py b/python/ray/tune/tests/test_progress_reporter.py index 2bc68ff49a59..b7c7adfe0ab9 100644 --- a/python/ray/tune/tests/test_progress_reporter.py +++ b/python/ray/tune/tests/test_progress_reporter.py @@ -203,7 +203,8 @@ def f(config): EXPECTED_BEST_2 = "Current best trial: 00004 with metric_1=2.0 and parameters={'a': 4}" -EXPECTED_SORT_RESULT_UNSORTED = """Number of trials: 5 (1 PENDING, 1 RUNNING, 3 TERMINATED) +EXPECTED_SORT_RESULT_UNSORTED = """\ +Number of trials: 5 (1 PENDING, 1 RUNNING, 3 TERMINATED) +--------------+------------+-------+-----+------------+ | Trial name | status | loc | a | metric_1 | |--------------+------------+-------+-----+------------| @@ -225,7 +226,8 @@ def f(config): +--------------+------------+-------+-----+------------+ ... 1 more trials not shown (1 TERMINATED)""" -EXPECTED_NESTED_SORT_RESULT = """Number of trials: 5 (1 PENDING, 1 RUNNING, 3 TERMINATED) +EXPECTED_NESTED_SORT_RESULT = """\ +Number of trials: 5 (1 PENDING, 1 RUNNING, 3 TERMINATED) +--------------+------------+-------+-----+-------------------+ | Trial name | status | loc | a | nested/metric_2 | |--------------+------------+-------+-----+-------------------| @@ -279,7 +281,8 @@ def f(config): VERBOSE_TRIAL_WITH_ONCE_RESULT = "Result for train_fn_xxxxx_00001" VERBOSE_TRIAL_WITH_ONCE_COMPLETED = "Trial train_fn_xxxxx_00001 completed." -VERBOSE_TRIAL_DETAIL = """+-------------------+----------+-------------------+----------+ +VERBOSE_TRIAL_DETAIL = """\ ++-------------------+----------+-------------------+----------+ | Trial name | status | loc | do | |-------------------+----------+-------------------+----------| | train_fn_xxxxx_00000 | RUNNING | 123.123.123.123:1 | complete |""" @@ -423,7 +426,7 @@ def report(self, *args, **kwargs): reporter = TestReporter() analysis = tune.run(test, num_samples=3, progress_reporter=reporter, verbose=3) - found = {k: False for k in test_result} + found = dict.fromkeys(test_result, False) for output in reporter._output: for key in test_result: if key in output: diff --git a/python/ray/tune/tests/test_sample.py b/python/ray/tune/tests/test_sample.py index 8e1d271ad40f..a0c62aa44e5a 100644 --- a/python/ray/tune/tests/test_sample.py +++ b/python/ray/tune/tests/test_sample.py @@ -29,7 +29,7 @@ def assertDictAlmostEqual(a, b): assert k in b, f"Key {k} not found in {b}" w = b[k] - assert type(v) == type(w), f"Type {type(v)} is not {type(w)}" + assert isinstance(v, w), f"Type {type(v)} is not {type(w)}" if isinstance(v, dict): assert assertDictAlmostEqual(v, w), f"Subdict {v} != {w}" diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 879bf53f7d58..8f821fb6dd42 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -258,6 +258,7 @@ def _ray_auto_init(entrypoint: str): class _Config(abc.ABC): + @abc.abstractmethod def to_dict(self) -> dict: """Converts this configuration to a dict format.""" raise NotImplementedError @@ -648,6 +649,7 @@ def run( "keep_checkpoints_num is deprecated and will be removed. " "use checkpoint_config.num_to_keep instead.", DeprecationWarning, + stacklevel=2, ) checkpoint_config.num_to_keep = keep_checkpoints_num if checkpoint_score_attr is not None: @@ -655,6 +657,7 @@ def run( "checkpoint_score_attr is deprecated and will be removed. " "use checkpoint_config.checkpoint_score_attribute instead.", DeprecationWarning, + stacklevel=2, ) if checkpoint_score_attr.startswith("min-"): @@ -663,6 +666,7 @@ def run( "order is deprecated. Use CheckpointConfig.checkpoint_score_order " "instead", DeprecationWarning, + stacklevel=2, ) checkpoint_config.checkpoint_score_attribute = checkpoint_score_attr[4:] checkpoint_config.checkpoint_score_order = "min" @@ -676,6 +680,7 @@ def run( "checkpoint_freq is deprecated and will be removed. " "use checkpoint_config.checkpoint_frequency instead.", DeprecationWarning, + stacklevel=2, ) checkpoint_config.checkpoint_frequency = checkpoint_freq if checkpoint_at_end: @@ -683,6 +688,7 @@ def run( "checkpoint_at_end is deprecated and will be removed. " "use checkpoint_config.checkpoint_at_end instead.", DeprecationWarning, + stacklevel=2, ) checkpoint_config.checkpoint_at_end = checkpoint_at_end @@ -719,7 +725,8 @@ def run( f"TUNE_RESULT_BUFFER_LENGTH is set " f"({env_result_buffer_length}). This can lead to undesired " f"and faulty behavior, so the buffer length was forcibly set " - f"to 1 instead." + f"to 1 instead.", + stacklevel=2, ) os.environ["TUNE_RESULT_BUFFER_LENGTH"] = "1" @@ -729,7 +736,8 @@ def run( ): warnings.warn( "Consider boosting PBT performance by enabling `reuse_actors` as " - "well as implementing `reset_config` for Trainable." + "well as implementing `reset_config` for Trainable.", + stacklevel=2, ) # Before experiments are created, we first clean up the passed in diff --git a/python/ray/tune/tuner.py b/python/ray/tune/tuner.py index 86d7cae55375..c0c0f61eda5b 100644 --- a/python/ray/tune/tuner.py +++ b/python/ray/tune/tuner.py @@ -42,7 +42,8 @@ @PublicAPI(stability="beta") class Tuner: - """Tuner is the recommended way of launching hyperparameter tuning jobs with Ray Tune. + """ + Tuner is the recommended way of launching hyperparameter tuning jobs with Ray Tune. Args: trainable: The trainable to be tuned. diff --git a/python/ray/util/client/common.py b/python/ray/util/client/common.py index caf5572c69d0..0c1b107c1f64 100644 --- a/python/ray/util/client/common.py +++ b/python/ray/util/client/common.py @@ -707,7 +707,7 @@ def _get_client_id_from_context(context: Any) -> str: Get `client_id` from gRPC metadata. If the `client_id` is not present, this function logs an error and sets the status_code. """ - metadata = {k: v for k, v in context.invocation_metadata()} + metadata = dict(context.invocation_metadata()) client_id = metadata.get("client_id") or "" if client_id == "": logger.error("Client connecting with no client_id") diff --git a/python/ray/util/client/dataclient.py b/python/ray/util/client/dataclient.py index 5ce08117087d..7cb1c07cf98b 100644 --- a/python/ray/util/client/dataclient.py +++ b/python/ray/util/client/dataclient.py @@ -57,6 +57,7 @@ def chunk_put(req: ray_client_pb2.DataRequest): "Documentation for doing this can be found here: " "https://docs.ray.io/en/latest/handling-dependencies.html#remote-uris", UserWarning, + stacklevel=2, ) total_chunks = math.ceil(total_size / OBJECT_TRANSFER_CHUNK_SIZE) for chunk_id in range(0, total_chunks): @@ -148,6 +149,7 @@ def __call__(self, response: Union[ray_client_pb2.DataResponse, Exception]) -> b "be slow. Consider serializing the object to a file and " "using rsync or S3 instead.", UserWarning, + stacklevel=2, ) chunk_data = get_resp.data chunk_id = get_resp.chunk_id diff --git a/python/ray/util/client/server/dataservicer.py b/python/ray/util/client/server/dataservicer.py index 9ce816856e4d..af06b8902785 100644 --- a/python/ray/util/client/server/dataservicer.py +++ b/python/ray/util/client/server/dataservicer.py @@ -32,7 +32,7 @@ def _get_reconnecting_from_context(context: Any) -> bool: """ Get `reconnecting` from gRPC metadata, or False if missing. """ - metadata = {k: v for k, v in context.invocation_metadata()} + metadata = dict(context.invocation_metadata()) val = metadata.get("reconnecting") if val is None or val not in ("True", "False"): logger.error( @@ -155,7 +155,7 @@ def Datapath(self, request_iterator, context): start_time = time.time() # set to True if client shuts down gracefully cleanup_requested = False - metadata = {k: v for k, v in context.invocation_metadata()} + metadata = dict(context.invocation_metadata()) client_id = metadata.get("client_id") if client_id is None: logger.error("Client connecting with no client_id") diff --git a/python/ray/util/client/worker.py b/python/ray/util/client/worker.py index 57acede6bd4d..f6a1dca79231 100644 --- a/python/ray/util/client/worker.py +++ b/python/ray/util/client/worker.py @@ -2,6 +2,7 @@ It implements the Ray API functions that are forwarded through grpc calls to the server. """ + import base64 import json import logging @@ -255,7 +256,8 @@ def _connect_channel(self, reconnecting=False) -> None: "the Ray Client port on the head node is reachable " "from your local machine. See https://docs.ray.io/en" "/latest/cluster/ray-client.html#step-2-check-ports for " - "more information." + "more information.", + stacklevel=2, ) raise ConnectionError("ray client connection timeout") @@ -631,6 +633,7 @@ def populate_ids(resp: Union[ray_client_pb2.DataResponse, Exception]) -> None: "document, available here: " f"{DESIGN_PATTERN_LARGE_OBJECTS_LINK}", UserWarning, + stacklevel=2, ) return id_futures @@ -724,7 +727,7 @@ def get_cluster_info( resp = self.server.ClusterInfo(req, timeout=timeout, metadata=self.metadata) if resp.WhichOneof("response_type") == "resource_table": # translate from a proto map to a python dict - output_dict = {k: v for k, v in resp.resource_table.table.items()} + output_dict = dict(resp.resource_table.table) return output_dict elif resp.WhichOneof("response_type") == "runtime_context": return resp.runtime_context diff --git a/python/ray/util/collective/collective_group/base_collective_group.py b/python/ray/util/collective/collective_group/base_collective_group.py index 1272d946f0a3..3babeeb306dd 100644 --- a/python/ray/util/collective/collective_group/base_collective_group.py +++ b/python/ray/util/collective/collective_group/base_collective_group.py @@ -1,4 +1,5 @@ """Abstract class for collective groups.""" + from abc import ABCMeta from abc import abstractmethod @@ -40,6 +41,7 @@ def group_name(self): """Return the group name of this group.""" return self._group_name + @abstractmethod def destroy_group(self): """GC the communicators.""" pass diff --git a/python/ray/util/dask/optimizations.py b/python/ray/util/dask/optimizations.py index 1f1f910f07b1..c27381796464 100644 --- a/python/ray/util/dask/optimizations.py +++ b/python/ray/util/dask/optimizations.py @@ -139,7 +139,8 @@ def dataframe_optimize(dsk, keys, **kwargs): "Custom dataframe shuffle optimization only works on " "dask>=2020.12.0, you are on version " f"{dask.__version__}, please upgrade Dask." - "Falling back to default dataframe optimizer." + "Falling back to default dataframe optimizer.", + stacklevel=2, ) return optimize(dsk, keys, **kwargs) diff --git a/python/ray/util/dask/scheduler.py b/python/ray/util/dask/scheduler.py index f3dbca063b5a..4cec21e028b7 100644 --- a/python/ray/util/dask/scheduler.py +++ b/python/ray/util/dask/scheduler.py @@ -298,7 +298,7 @@ def _rayify_task_wrapper( id = get_id() result = dumps((result, id)) failed = False - except BaseException as e: + except BaseException as e: # noqa: B036 To avoid breaking change result = pack_exception(e, dumps) failed = True return key, result, failed diff --git a/python/ray/util/dask/scheduler_utils.py b/python/ray/util/dask/scheduler_utils.py index efb7b18bd911..9cfd87ee7a68 100644 --- a/python/ray/util/dask/scheduler_utils.py +++ b/python/ray/util/dask/scheduler_utils.py @@ -106,7 +106,8 @@ def execute_task(key, task_info, dumps, loads, get_id, pack_exception): id = get_id() result = dumps((result, id)) failed = False - except BaseException as e: + + except BaseException as e: # noqa: B036 To avoid breaking change result = pack_exception(e, dumps) failed = True return key, result, failed diff --git a/python/ray/util/spark/cluster_init.py b/python/ray/util/spark/cluster_init.py index c9852fad3828..5bc636fca617 100644 --- a/python/ray/util/spark/cluster_init.py +++ b/python/ray/util/spark/cluster_init.py @@ -851,6 +851,7 @@ def _setup_ray_cluster_internal( "'num_cpus_per_node' argument is deprecated, please use " "'num_cpus_worker_node' argument instead.", DeprecationWarning, + stacklevel=2, ) if "num_gpus_per_node" in kwargs: @@ -864,6 +865,7 @@ def _setup_ray_cluster_internal( "'num_gpus_per_node' argument is deprecated, please use " "'num_gpus_worker_node' argument instead.", DeprecationWarning, + stacklevel=2, ) if "object_store_memory_per_node" in kwargs: @@ -878,6 +880,7 @@ def _setup_ray_cluster_internal( "'object_store_memory_per_node' argument is deprecated, please use " "'object_store_memory_worker_node' argument instead.", DeprecationWarning, + stacklevel=2, ) # Environment configurations within the Spark Session that dictate how many cpus diff --git a/python/ray/util/state/api.py b/python/ray/util/state/api.py index acbe750d906a..cbce31d4e4b2 100644 --- a/python/ray/util/state/api.py +++ b/python/ray/util/state/api.py @@ -361,7 +361,7 @@ def _print_api_warning( if warn_data_source_not_available: warning_msgs = api_response.get("partial_failure_warning", None) if warning_msgs: - warnings.warn(warning_msgs) + warnings.warn(warning_msgs, stacklevel=2) if warn_data_truncation: # Print warnings if data is truncated at the data source. @@ -382,6 +382,7 @@ def _print_api_warning( f"Max of {num_after_truncation} entries are retrieved " "from data source to prevent over-sized payloads." ), + stacklevel=2, ) if warn_limit: @@ -397,6 +398,7 @@ def _print_api_warning( "the amount of data to return or " "setting a higher limit with `--limit` to see all data. " ), + stacklevel=2, ) if warn_server_side_warnings: @@ -404,7 +406,7 @@ def _print_api_warning( warnings_to_print = api_response.get("warnings", []) if warnings_to_print: for warning_to_print in warnings_to_print: - warnings.warn(warning_to_print) + warnings.warn(warning_to_print, stacklevel=2) def _raise_on_missing_output(self, resource: StateResource, api_response: dict): """Raise an exception when the API resopnse contains a missing output. diff --git a/python/ray/util/state/common.py b/python/ray/util/state/common.py index 686d5355af4a..94b7d3e40d3a 100644 --- a/python/ray/util/state/common.py +++ b/python/ray/util/state/common.py @@ -2,7 +2,6 @@ import json import logging import sys -from abc import ABC from dataclasses import asdict, field, fields from enum import Enum, unique from typing import Any, Dict, List, Optional, Set, Tuple, Union @@ -216,7 +215,7 @@ def state_column(*, filterable: bool, detail: bool = False, format_fn=None, **kw return field(**kwargs) -class StateSchema(ABC): +class StateSchema: """Schema class for Ray resource abstraction. The child class must be dataclass. All child classes diff --git a/python/ray/util/state/util.py b/python/ray/util/state/util.py index 16a5221e458f..e50f46a223f5 100644 --- a/python/ray/util/state/util.py +++ b/python/ray/util/state/util.py @@ -56,6 +56,7 @@ def record_deprecated_state_api_import(): "instead. Importing from `ray.experimental` will be deprecated in " "future releases. ", DeprecationWarning, + stacklevel=2, ) record_extra_usage_tag(TagKey.EXPERIMENTAL_STATE_API_IMPORT, "1") diff --git a/python/ray/workflow/http_event_provider.py b/python/ray/workflow/http_event_provider.py index 5a25fc97c1ee..5ed0b3b87376 100644 --- a/python/ray/workflow/http_event_provider.py +++ b/python/ray/workflow/http_event_provider.py @@ -31,7 +31,9 @@ def __init__(self, workflow_id: str, what_happened: str): @serve.deployment(num_replicas=1) @serve.ingress(app) class HTTPEventProvider: - """HTTPEventProvider is defined to be a Ray Serve deployment with route_prefix='/event', + """ + HTTPEventProvider is defined to be a Ray Serve deployment + with route_prefix='/event', which will receive external events via an HTTP endpoint. It supports FastAPI, e.g. post. It responds to both poll_for_event() and event_checkpointed() from an HTTPListener instance. diff --git a/python/ray/workflow/tests/test_event_resume_after_crash.py b/python/ray/workflow/tests/test_event_resume_after_crash.py index 2068d74b586f..8280ef22c845 100644 --- a/python/ray/workflow/tests/test_event_resume_after_crash.py +++ b/python/ray/workflow/tests/test_event_resume_after_crash.py @@ -27,8 +27,9 @@ indirect=True, ) def test_cluster_crash_before_checkpoint(workflow_start_regular_shared_serve): - """If the cluster crashed before the event was checkpointed, after the cluster restarted - and the workflow resumed, the new event message is processed by the workflow. + """If the cluster crashed before the event was checkpointed, + after the cluster restarted and the workflow resumed, + the new event message is processed by the workflow. """ class CustomHTTPListener(HTTPListener): diff --git a/python/requirements/lint-requirements.txt b/python/requirements/lint-requirements.txt index 3a0889a28853..d04bc2c0fe6b 100644 --- a/python/requirements/lint-requirements.txt +++ b/python/requirements/lint-requirements.txt @@ -1,9 +1,9 @@ clang-format==12.0.1 docutils -flake8==3.9.1 -flake8-comprehensions==3.10.1 -flake8-quotes==2.0.0 -flake8-bugbear==21.9.2 +flake8==7.1.1 +flake8-quotes==3.4.0 +flake8-comprehensions==3.15.0 +flake8-bugbear==24.8.19 mypy==1.7.0 mypy-extensions==1.0.0 types-PyYAML==6.0.12.2 diff --git a/python/requirements_compiled.txt b/python/requirements_compiled.txt index a1043afc5b51..739e2a52e242 100644 --- a/python/requirements_compiled.txt +++ b/python/requirements_compiled.txt @@ -521,17 +521,17 @@ filelock==3.13.1 # torch # transformers # virtualenv -flake8==3.9.1 +flake8==7.1.1 # via # -r /ray/ci/../python/requirements/lint-requirements.txt # flake8-bugbear # flake8-comprehensions # flake8-quotes -flake8-bugbear==21.9.2 +flake8-bugbear==24.8.19 # via -r /ray/ci/../python/requirements/lint-requirements.txt -flake8-comprehensions==3.10.1 +flake8-comprehensions==3.15.0 # via -r /ray/ci/../python/requirements/lint-requirements.txt -flake8-quotes==2.0.0 +flake8-quotes==3.4.0 # via -r /ray/ci/../python/requirements/lint-requirements.txt flask==2.1.3 # via @@ -1048,7 +1048,7 @@ matplotlib-inline==0.1.6 # via # ipykernel # ipython -mccabe==0.6.1 +mccabe==0.7.0 # via flake8 mdit-py-plugins==0.3.5 # via diff --git a/python/setup.py b/python/setup.py index 35cd57a70f20..046feeb3359f 100644 --- a/python/setup.py +++ b/python/setup.py @@ -568,6 +568,7 @@ def build(build_python, build_java, build_cpp): "Setting BAZEL_LIMIT_CPUS is deprecated and will be removed in a future" " version. Please use BAZEL_ARGS instead.", FutureWarning, + stacklevel=2, ) if is_automated_build: @@ -772,51 +773,55 @@ def has_ext_modules(self): if os.path.isdir(build_dir): shutil.rmtree(build_dir) -setuptools.setup( - name=setup_spec.name, - version=setup_spec.version, - author="Ray Team", - author_email="ray-dev@googlegroups.com", - description=(setup_spec.description), - long_description=io.open( - os.path.join(ROOT_DIR, os.path.pardir, "README.rst"), "r", encoding="utf-8" - ).read(), - url="https://github.com/ray-project/ray", - keywords=( - "ray distributed parallel machine-learning hyperparameter-tuning" - "reinforcement-learning deep-learning serving python" - ), - python_requires=">=3.9", - classifiers=[ - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - ], - packages=setup_spec.get_packages(), - cmdclass={"build_ext": build_ext}, - # The BinaryDistribution argument triggers build_ext. - distclass=BinaryDistribution, - install_requires=setup_spec.install_requires, - setup_requires=["cython >= 0.29.32", "wheel"], - extras_require=setup_spec.extras, - entry_points={ - "console_scripts": [ - "ray=ray.scripts.scripts:main", - "rllib=ray.rllib.scripts:cli [rllib]", - "tune=ray.tune.cli.scripts:cli", - "serve=ray.serve.scripts:cli", - ] - }, - package_data={ - "ray": ["includes/*.pxd", "*.pxd", "data/_internal/logging.yaml"], - }, - include_package_data=True, - exclude_package_data={ - # Empty string means "any package". - # Therefore, exclude BUILD from every package: - "": ["BUILD"], - }, - zip_safe=False, - license="Apache 2.0", -) if __name__ == "__main__" else None +( + setuptools.setup( + name=setup_spec.name, + version=setup_spec.version, + author="Ray Team", + author_email="ray-dev@googlegroups.com", + description=(setup_spec.description), + long_description=io.open( + os.path.join(ROOT_DIR, os.path.pardir, "README.rst"), "r", encoding="utf-8" + ).read(), + url="https://github.com/ray-project/ray", + keywords=( + "ray distributed parallel machine-learning hyperparameter-tuning" + "reinforcement-learning deep-learning serving python" + ), + python_requires=">=3.9", + classifiers=[ + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + ], + packages=setup_spec.get_packages(), + cmdclass={"build_ext": build_ext}, + # The BinaryDistribution argument triggers build_ext. + distclass=BinaryDistribution, + install_requires=setup_spec.install_requires, + setup_requires=["cython >= 0.29.32", "wheel"], + extras_require=setup_spec.extras, + entry_points={ + "console_scripts": [ + "ray=ray.scripts.scripts:main", + "rllib=ray.rllib.scripts:cli [rllib]", + "tune=ray.tune.cli.scripts:cli", + "serve=ray.serve.scripts:cli", + ] + }, + package_data={ + "ray": ["includes/*.pxd", "*.pxd", "data/_internal/logging.yaml"], + }, + include_package_data=True, + exclude_package_data={ + # Empty string means "any package". + # Therefore, exclude BUILD from every package: + "": ["BUILD"], + }, + zip_safe=False, + license="Apache 2.0", + ) + if __name__ == "__main__" + else None +) diff --git a/release/air_tests/air_benchmarks/mlperf-train/resnet50_ray_air.py b/release/air_tests/air_benchmarks/mlperf-train/resnet50_ray_air.py index a1c98c7f680d..cf05aa5a1656 100644 --- a/release/air_tests/air_benchmarks/mlperf-train/resnet50_ray_air.py +++ b/release/air_tests/air_benchmarks/mlperf-train/resnet50_ray_air.py @@ -529,7 +529,7 @@ def append_to_test_output_json(path, metrics): available_disk_space = statvfs.f_bavail * statvfs.f_frsize expected_disk_usage = args.num_images_per_epoch * APPROX_PREPROCESS_IMAGE_BYTES print(f"Available disk space: {available_disk_space / 1e9}GB") - print(f"Expected disk usage: {expected_disk_usage/ 1e9}GB") + print(f"Expected disk usage: {expected_disk_usage / 1e9}GB") disk_error_expected = expected_disk_usage > available_disk_space * 0.8 datasets = {} diff --git a/release/air_tests/horovod/workloads/horovod_tune_test.py b/release/air_tests/horovod/workloads/horovod_tune_test.py index 3f009063c08a..86ab31461420 100755 --- a/release/air_tests/horovod/workloads/horovod_tune_test.py +++ b/release/air_tests/horovod/workloads/horovod_tune_test.py @@ -52,7 +52,7 @@ def train_loop_per_worker(config): net.parameters(), lr=config["lr"], ) - epoch = 0 + epochs = 0 checkpoint = train.get_checkpoint() if checkpoint: @@ -62,7 +62,7 @@ def train_loop_per_worker(config): optimizer_state = torch.load( checkpoint_dir / "optim.pt", map_location="cpu" ) - epoch = torch.load(checkpoint_dir / "extra_state.pt")["epoch"] + 1 + epochs = torch.load(checkpoint_dir / "extra_state.pt")["epoch"] + 1 net.load_state_dict(model_state) optimizer.load_state_dict(optimizer_state) @@ -88,7 +88,7 @@ def train_loop_per_worker(config): trainset, batch_size=int(config["batch_size"]), sampler=train_sampler ) - for epoch in range(epoch, 40): # loop over the dataset multiple times + for epoch in range(epochs, 40): # loop over the dataset multiple times running_loss = 0.0 epoch_steps = 0 for i, data in enumerate(trainloader): diff --git a/release/nightly_tests/dask_on_ray/dask_on_ray_sort.py b/release/nightly_tests/dask_on_ray/dask_on_ray_sort.py index 5733af65f75b..d137cc2c1fe7 100644 --- a/release/nightly_tests/dask_on_ray/dask_on_ray_sort.py +++ b/release/nightly_tests/dask_on_ray/dask_on_ray_sort.py @@ -269,6 +269,6 @@ def trial( "dask_nthreads": args.dask_nthreads, "dask_memlimit": args.dask_memlimit, } - for output in output: - row["duration"] = output + for out in output: + row["duration"] = out writer.writerow(row) diff --git a/release/nightly_tests/stress_tests/test_placement_group.py b/release/nightly_tests/stress_tests/test_placement_group.py index 93705bf22c5a..43165d768a53 100644 --- a/release/nightly_tests/stress_tests/test_placement_group.py +++ b/release/nightly_tests/stress_tests/test_placement_group.py @@ -186,7 +186,7 @@ def pg_launcher(pre_created_pgs, num_pgs_to_create): ) print( "Avg placement group removing time: " - f"{total_removing_time / total_trial* 1000} ms" + f"{total_removing_time / total_trial * 1000} ms" ) print("PASSED.") diff --git a/release/ray_release/cluster_manager/cluster_manager.py b/release/ray_release/cluster_manager/cluster_manager.py index fac34cc00eb6..393c7684ff5a 100644 --- a/release/ray_release/cluster_manager/cluster_manager.py +++ b/release/ray_release/cluster_manager/cluster_manager.py @@ -114,12 +114,15 @@ def _annotate_cluster_compute( ) return cluster_compute + @abc.abstractmethod def build_configs(self, timeout: float = 30.0): raise NotImplementedError + @abc.abstractmethod def delete_configs(self): raise NotImplementedError + @abc.abstractmethod def start_cluster(self, timeout: float = 600.0): raise NotImplementedError @@ -129,9 +132,11 @@ def terminate_cluster(self, wait: bool = False): except Exception as e: logger.exception(f"Could not terminate cluster: {e}") + @abc.abstractmethod def terminate_cluster_ex(self, wait: bool = False): raise NotImplementedError + @abc.abstractmethod def get_cluster_address(self) -> str: raise NotImplementedError diff --git a/release/ray_release/file_manager/file_manager.py b/release/ray_release/file_manager/file_manager.py index a9e6594a4af8..2c278fcb6d5f 100644 --- a/release/ray_release/file_manager/file_manager.py +++ b/release/ray_release/file_manager/file_manager.py @@ -8,6 +8,7 @@ class FileManager(abc.ABC): def __init__(self, cluster_manager: ClusterManager): self.cluster_manager = cluster_manager + @abc.abstractmethod def upload(self, source: Optional[str] = None, target: Optional[str] = None): """Upload source to target. @@ -15,6 +16,7 @@ def upload(self, source: Optional[str] = None, target: Optional[str] = None): """ raise NotImplementedError + @abc.abstractmethod def download(self, source: str, target: str): """Download source_dir to target_dir.""" raise NotImplementedError diff --git a/rllib/algorithms/marwil/tests/test_marwil.py b/rllib/algorithms/marwil/tests/test_marwil.py index 5c2584d2ed82..bfb1047214c9 100644 --- a/rllib/algorithms/marwil/tests/test_marwil.py +++ b/rllib/algorithms/marwil/tests/test_marwil.py @@ -168,9 +168,7 @@ def possibly_masked_mean(data_): # Calculate our own expected values (to then compare against the # agent's loss output). module = algo.learner_group._learner.module[DEFAULT_MODULE_ID].unwrapped() - fwd_out = module.forward_train( - {k: v for k, v in batch[DEFAULT_MODULE_ID].items()} - ) + fwd_out = module.forward_train(dict(batch[DEFAULT_MODULE_ID])) advantages = ( batch[DEFAULT_MODULE_ID][Columns.VALUE_TARGETS].detach().cpu().numpy() - module.compute_values(batch[DEFAULT_MODULE_ID]).detach().cpu().numpy() @@ -199,7 +197,7 @@ def possibly_masked_mean(data_): # calculation above). total_loss = algo.learner_group._learner.compute_loss_for_module( module_id=DEFAULT_MODULE_ID, - batch={k: v for k, v in batch[DEFAULT_MODULE_ID].items()}, + batch=dict(batch[DEFAULT_MODULE_ID]), fwd_out=fwd_out, config=config, ) diff --git a/rllib/algorithms/sac/tests/test_sac.py b/rllib/algorithms/sac/tests/test_sac.py index b9f0eba34ec8..74bf7ff3bee0 100644 --- a/rllib/algorithms/sac/tests/test_sac.py +++ b/rllib/algorithms/sac/tests/test_sac.py @@ -138,9 +138,7 @@ def test_sac_dict_obs_order(self): # Dict space .sample() returns an ordered dict. # Make sure the keys in samples are ordered differently. - dict_samples = [ - {k: v for k, v in reversed(dict_space.sample().items())} for _ in range(10) - ] + dict_samples = [dict(reversed(dict_space.sample().items())) for _ in range(10)] class NestedDictEnv(gym.Env): def __init__(self): diff --git a/rllib/connectors/action/lambdas.py b/rllib/connectors/action/lambdas.py index 3bf862dd834d..658640f8dcbc 100644 --- a/rllib/connectors/action/lambdas.py +++ b/rllib/connectors/action/lambdas.py @@ -19,7 +19,8 @@ def register_lambda_action_connector( name: str, fn: Callable[[TensorStructType, StateBatches, Dict], PolicyOutputType] ) -> Type[ActionConnector]: - """A util to register any function transforming PolicyOutputType as an ActionConnector. + """ + A util to register any function transforming PolicyOutputType as an ActionConnector. The only requirement is that fn should take actions, states, and fetches as input, and return transformed actions, states, and fetches. diff --git a/rllib/connectors/action/pipeline.py b/rllib/connectors/action/pipeline.py index b96296e75842..0674773df5fd 100644 --- a/rllib/connectors/action/pipeline.py +++ b/rllib/connectors/action/pipeline.py @@ -44,8 +44,8 @@ def to_state(self): @staticmethod def from_state(ctx: ConnectorContext, params: Any): - assert ( - type(params) == list + assert isinstance( + params, list ), "ActionConnectorPipeline takes a list of connector params." connectors = [] for state in params: diff --git a/rllib/connectors/agent/clip_reward.py b/rllib/connectors/agent/clip_reward.py index 9d55e4aea24a..71715828e3f5 100644 --- a/rllib/connectors/agent/clip_reward.py +++ b/rllib/connectors/agent/clip_reward.py @@ -24,8 +24,8 @@ def __init__(self, ctx: ConnectorContext, sign=False, limit=None): def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType: d = ac_data.data - assert ( - type(d) == dict + assert isinstance( + d, dict ), "Single agent data must be of type Dict[str, TensorStructType]" if SampleBatch.REWARDS not in d: diff --git a/rllib/connectors/agent/mean_std_filter.py b/rllib/connectors/agent/mean_std_filter.py index 64de99164267..abdf43cb497c 100644 --- a/rllib/connectors/agent/mean_std_filter.py +++ b/rllib/connectors/agent/mean_std_filter.py @@ -51,8 +51,8 @@ def __init__( def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType: d = ac_data.data - assert ( - type(d) == dict + assert isinstance( + d, dict ), "Single agent data must be of type Dict[str, TensorStructType]" if SampleBatch.OBS in d: d[SampleBatch.OBS] = self.filter( diff --git a/rllib/connectors/agent/obs_preproc.py b/rllib/connectors/agent/obs_preproc.py index 41b23745c05a..478d4368b951 100644 --- a/rllib/connectors/agent/obs_preproc.py +++ b/rllib/connectors/agent/obs_preproc.py @@ -44,7 +44,7 @@ def is_identity(self): def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType: d = ac_data.data - assert type(d) == dict, ( + assert isinstance(d, dict), ( "Single agent data must be of type Dict[str, TensorStructType] but is of " "type {}".format(type(d)) ) diff --git a/rllib/connectors/agent/pipeline.py b/rllib/connectors/agent/pipeline.py index 0113c1c45887..1f5ad48ace70 100644 --- a/rllib/connectors/agent/pipeline.py +++ b/rllib/connectors/agent/pipeline.py @@ -55,8 +55,8 @@ def to_state(self): @staticmethod def from_state(ctx: ConnectorContext, params: List[Any]): - assert ( - type(params) == list + assert isinstance( + params, list ), "AgentConnectorPipeline takes a list of connector params." connectors = [] for state in params: diff --git a/rllib/connectors/connector.py b/rllib/connectors/connector.py index 80c5003085b1..0fe4c65be2e0 100644 --- a/rllib/connectors/connector.py +++ b/rllib/connectors/connector.py @@ -121,6 +121,7 @@ def to_state(self) -> Tuple[str, Any]: return NotImplementedError @staticmethod + @abc.abstractmethod def from_state(self, ctx: ConnectorContext, params: Any) -> "Connector": """De-serialize a JSON params back into a Connector. @@ -333,7 +334,7 @@ def transform(self, ac_data: ActionConnectorDataType) -> ActionConnectorDataType @OldAPIStack -class ConnectorPipeline(abc.ABC): +class ConnectorPipeline: """Utility class for quick manipulation of a connector pipeline.""" def __init__(self, ctx: ConnectorContext, connectors: List[Connector]): diff --git a/rllib/connectors/tests/test_agent.py b/rllib/connectors/tests/test_agent.py index 6deb2dc29077..e3bcd9acde4c 100644 --- a/rllib/connectors/tests/test_agent.py +++ b/rllib/connectors/tests/test_agent.py @@ -300,8 +300,11 @@ def test_vr_connector_respects_training_or_inference_vr_flags(self): check(sample_batch, sample_batch_expected) def test_vr_connector_shift_by_one(self): - """Test that the ViewRequirementAgentConnector can handle shift by one correctly and - can ignore future referencing view_requirements to respect causality""" + """ + Test that the ViewRequirementAgentConnector can handle shift + by one correctly and can ignore future referencing + view_requirements to respect causality + """ view_rq_dict = { "state": ViewRequirement("obs"), "next_state": ViewRequirement( diff --git a/rllib/connectors/util.py b/rllib/connectors/util.py index e0ffbcea29d6..674da2e46356 100644 --- a/rllib/connectors/util.py +++ b/rllib/connectors/util.py @@ -54,7 +54,7 @@ def get_agent_connectors_from_config( clip_rewards = __clip_rewards(config) if clip_rewards is True: connectors.append(ClipRewardAgentConnector(ctx, sign=True)) - elif type(clip_rewards) == float: + elif isinstance(clip_rewards, float): connectors.append(ClipRewardAgentConnector(ctx, limit=abs(clip_rewards))) if __preprocessing_enabled(config): diff --git a/rllib/env/remote_base_env.py b/rllib/env/remote_base_env.py index b9e388d50bcf..e569c318680a 100644 --- a/rllib/env/remote_base_env.py +++ b/rllib/env/remote_base_env.py @@ -250,7 +250,7 @@ def poll( # observations and infos: Set rewards, terminateds, and truncateds to # dummy values. if rew is None: - rew = {agent_id: 0 for agent_id in ob.keys()} + rew = dict.fromkeys(ob.keys(), 0) terminated = {"__all__": False} truncated = {"__all__": False} @@ -328,7 +328,7 @@ def stop(self) -> None: @override(BaseEnv) def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]: if as_dict: - return {env_id: actor for env_id, actor in enumerate(self.actors)} + return dict(enumerate(self.actors)) return self.actors @property diff --git a/rllib/env/tests/test_multi_agent_episode.py b/rllib/env/tests/test_multi_agent_episode.py index a0189b092339..b2e0da2b21a2 100644 --- a/rllib/env/tests/test_multi_agent_episode.py +++ b/rllib/env/tests/test_multi_agent_episode.py @@ -47,8 +47,8 @@ def reset( np.array(sorted(self._agent_ids)), num_agents_step, replace=False ) # Initialize observations. - init_obs = {agent_id: 0 for agent_id in agents_step} - init_info = {agent_id: {} for agent_id in agents_step} + init_obs = dict.fromkeys(agents_step, 0) + init_info = dict.fromkeys(agents_step, {}) # Reset all alive agents to all agents. self._agents_alive = set(self._agent_ids) @@ -77,7 +77,7 @@ def step( # Initialize observations. obs = {agent_id: self.t for agent_id in agents_step} info = {agent_id: {} for agent_id in agents_step} - reward = {agent_id: 1.0 for agent_id in agents_step} + reward = dict.fromkeys(agents_step, 1.0) # Add also agents without observations. reward.update( { @@ -91,9 +91,9 @@ def step( # Use tha last terminateds/truncateds. is_truncated = {"__all__": False} - is_truncated.update({agent_id: False for agent_id in agents_step}) + is_truncated.update(dict.fromkeys(agents_step, False)) is_terminated = {"__all__": False} - is_terminated.update({agent_id: False for agent_id in agents_step}) + is_terminated.update(dict.fromkeys(agents_step, False)) if self.t == 50: # Let agent 1 die. @@ -120,7 +120,7 @@ def step( # Truncate the episode if too long. if self.t >= 200 and self.truncate: is_truncated["__all__"] = True - is_truncated.update({agent_id: True for agent_id in agents_step}) + is_truncated.update(dict.fromkeys(agents_step, True)) return obs, reward, is_terminated, is_truncated, info @@ -489,9 +489,9 @@ def test_add_env_step(self): observation = {"agent_1": 3, "agent_2": 3} infos = {"agent_1": {}, "agent_2": {}} - terminated = {k: False for k in observation.keys()} + terminated = dict.fromkeys(observation.keys(), False) terminated.update({"__all__": False}) - truncated = {k: False for k in observation.keys()} + truncated = dict.fromkeys(observation.keys(), False) truncated.update({"__all__": False}) episode.add_env_step( observations=observation, @@ -2703,9 +2703,9 @@ def test_cut(self): # add this to the buffer and to the global reward history. reward = {"agent_1": 2.0, "agent_2": 2.0, "agent_3": 2.0, "agent_5": 2.0} info = {"agent_1": {}, "agent_2": {}} - terminateds = {k: False for k in observation.keys()} + terminateds = dict.fromkeys(observation.keys(), False) terminateds.update({"__all__": False}) - truncateds = {k: False for k in observation.keys()} + truncateds = dict.fromkeys(observation.keys(), False) truncateds.update({"__all__": False}) episode_1.add_env_step( observations=observation, @@ -3470,10 +3470,7 @@ def _mock_multi_agent_records_from_env( # In the other case we need at least the last observations for the next # actions. else: - obs = { - agent_id: agent_obs - for agent_id, agent_obs in episode.get_observations(-1).items() - } + obs = dict(episode.get_observations(-1).items()) # Sample `size` many records. done_agents = set() diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py index c3e0896ba05e..b1da92dd0cad 100644 --- a/rllib/env/vector_env.py +++ b/rllib/env/vector_env.py @@ -498,10 +498,7 @@ def get_sub_environments(self, as_dict: bool = False) -> Union[List[EnvType], di if not as_dict: return self.vector_env.get_sub_environments() else: - return { - _id: env - for _id, env in enumerate(self.vector_env.get_sub_environments()) - } + return dict(enumerate(self.vector_env.get_sub_environments())) @override(BaseEnv) def try_render(self, env_id: Optional[EnvID] = None) -> None: diff --git a/rllib/env/wrappers/dm_control_wrapper.py b/rllib/env/wrappers/dm_control_wrapper.py index 4c0a7407b9ae..577b499306b6 100644 --- a/rllib/env/wrappers/dm_control_wrapper.py +++ b/rllib/env/wrappers/dm_control_wrapper.py @@ -47,10 +47,10 @@ def _spec_to_box(spec): def extract_min_max(s): assert s.dtype == np.float64 or s.dtype == np.float32 dim = np.int_(np.prod(s.shape)) - if type(s) == specs.Array: + if isinstance(s, specs.Array): bound = np.inf * np.ones(dim, dtype=np.float32) return -bound, bound - elif type(s) == specs.BoundedArray: + elif isinstance(s, specs.BoundedArray): zeros = np.zeros(dim, dtype=np.float32) return s.minimum + zeros, s.maximum + zeros diff --git a/rllib/env/wrappers/multi_agent_env_compatibility.py b/rllib/env/wrappers/multi_agent_env_compatibility.py index fc8efeda0834..2e048b2363fe 100644 --- a/rllib/env/wrappers/multi_agent_env_compatibility.py +++ b/rllib/env/wrappers/multi_agent_env_compatibility.py @@ -61,7 +61,7 @@ def step( obs, rewards, terminateds, infos = self.env.step(action) # Truncated should always be False by default. - truncateds = {k: False for k in terminateds.keys()} + truncateds = dict.fromkeys(terminateds.keys(), False) return obs, rewards, terminateds, truncateds, infos diff --git a/rllib/env/wrappers/open_spiel.py b/rllib/env/wrappers/open_spiel.py index 1bc7ba119e68..fe0875e3f53d 100644 --- a/rllib/env/wrappers/open_spiel.py +++ b/rllib/env/wrappers/open_spiel.py @@ -62,7 +62,7 @@ def step(self, action): penalties[curr_player] = -0.1 # Compile rewards dict. - rewards = {ag: r for ag, r in enumerate(self.state.returns())} + rewards = dict(enumerate(self.state.returns())) # Simultaneous game. else: assert self.state.current_player() == -2 @@ -74,19 +74,18 @@ def step(self, action): # Compile rewards dict and add the accumulated penalties # (for taking invalid actions). - rewards = {ag: r for ag, r in enumerate(self.state.returns())} + rewards = dict(enumerate(self.state.returns())) for ag, penalty in penalties.items(): rewards[ag] += penalty # Are we done? is_terminated = self.state.is_terminal() - terminateds = dict( - {ag: is_terminated for ag in range(self.num_agents)}, - **{"__all__": is_terminated} - ) - truncateds = dict( - {ag: False for ag in range(self.num_agents)}, **{"__all__": False} - ) + terminateds = { + **dict.fromkeys(range(self.num_agents), is_terminated), + "__all__": is_terminated, + } + + truncateds = {**dict.fromkeys(range(self.num_agents), False), "__all__": False} return obs, rewards, terminateds, truncateds, {} diff --git a/rllib/env/wrappers/unity3d_env.py b/rllib/env/wrappers/unity3d_env.py index 45f0f910af92..7083a0099c15 100644 --- a/rllib/env/wrappers/unity3d_env.py +++ b/rllib/env/wrappers/unity3d_env.py @@ -183,7 +183,7 @@ def step( obs, rewards, terminateds, - dict({"__all__": True}, **{agent_id: True for agent_id in all_agents}), + {"__all__": True, **dict.fromkeys(all_agents, True)}, infos, ) diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 68dc3f638657..d0ab18dc2658 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -158,7 +158,7 @@ def __init__( # Agents to collect data from for the next forward pass (per policy). self.forward_pass_agent_keys = {pid: [] for pid in self.policy_map.keys()} - self.forward_pass_size = {pid: 0 for pid in self.policy_map.keys()} + self.forward_pass_size = dict.fromkeys(self.policy_map.keys(), 0) # Maps episode ID to the (non-built) env steps taken in this episode. self.episode_steps: Dict[EpisodeID, int] = collections.defaultdict(int) @@ -411,9 +411,11 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> Dict[str, TensorType] return SampleBatch( input_dict, - seq_lens=np.ones(batch_size, dtype=np.int32) - if "state_in_0" in input_dict - else None, + seq_lens=( + np.ones(batch_size, dtype=np.int32) + if "state_in_0" in input_dict + else None + ), ) @override(SampleCollector) diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 9e34fd237ee0..5a9078e59666 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -55,8 +55,7 @@ from ray.util.debug import log_once if TYPE_CHECKING: - from gymnasium.envs.classic_control.rendering import SimpleImageViewer - + from gymnasium.envs.classic_control.rendering import SimpleImageViewer # noqa: F401 from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.evaluation.observation_function import ObservationFunction from ray.rllib.evaluation.rollout_worker import RolloutWorker @@ -451,7 +450,7 @@ def _new_episode(env_id): # ImageViewer not defined yet, try to create one. if simple_image_viewer is None: try: - from gymnasium.envs.classic_control.rendering import ( + from gymnasium.envs.classic_control.rendering import ( # noqa: F811,E501 - to avoid brekaing change SimpleImageViewer, ) @@ -691,12 +690,16 @@ def _process_observations( agent_id, filtered_obs, agent_infos, - None - if last_observation is None - else episode.rnn_state_for(agent_id), - None - if last_observation is None - else episode.last_action_for(agent_id), + ( + None + if last_observation is None + else episode.rnn_state_for(agent_id) + ), + ( + None + if last_observation is None + else episode.last_action_for(agent_id) + ), rewards[env_id].get(agent_id, 0.0), ) to_eval[policy_id].append(item) @@ -1030,10 +1033,10 @@ def _process_policy_eval_results( episode: Episode = active_episodes[env_id] _assert_episode_not_faulty(episode) episode._set_rnn_state( - agent_id, tree.map_structure(lambda x: x[i], rnn_out_cols) + agent_id, tree.map_structure(lambda x, i=i: x[i], rnn_out_cols) ) episode._set_last_extra_action_outs( - agent_id, tree.map_structure(lambda x: x[i], extra_action_out_cols) + agent_id, tree.map_structure(lambda x, i=i: x[i], extra_action_out_cols) ) if env_id in off_policy_actions and agent_id in off_policy_actions[env_id]: episode._set_last_action(agent_id, off_policy_actions[env_id][agent_id]) diff --git a/rllib/examples/rl_modules/classes/modelv2_to_rlm.py b/rllib/examples/rl_modules/classes/modelv2_to_rlm.py index 5efbead7e66f..4cfa6d34d67d 100644 --- a/rllib/examples/rl_modules/classes/modelv2_to_rlm.py +++ b/rllib/examples/rl_modules/classes/modelv2_to_rlm.py @@ -187,7 +187,7 @@ def compute_values(self, batch: Dict[str, Any], embeddings: Optional[Any] = None def get_initial_state(self): """Converts the initial state list of ModelV2 into a dict (new API stack).""" init_state_list = self._model_v2.get_initial_state() - return {i: s for i, s in enumerate(init_state_list)} + return dict(enumerate(init_state_list)) def _translate_dist_class(self, old_dist_class): map_ = { diff --git a/rllib/models/torch/mingpt.py b/rllib/models/torch/mingpt.py index 4bf54aa2fe8e..7e24cfdc730a 100644 --- a/rllib/models/torch/mingpt.py +++ b/rllib/models/torch/mingpt.py @@ -193,7 +193,7 @@ def configure_gpt_optimizer( no_decay.add(fpn) # validate that we considered every parameter - param_dict = {pn: p for pn, p in model.named_parameters()} + param_dict = dict(model.named_parameters()) inter_params = decay & no_decay union_params = decay | no_decay assert ( diff --git a/rllib/offline/offline_evaluator.py b/rllib/offline/offline_evaluator.py index 60b87ff1296d..099b8d84b2b9 100644 --- a/rllib/offline/offline_evaluator.py +++ b/rllib/offline/offline_evaluator.py @@ -55,6 +55,7 @@ def train(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]: """ return {} + @abc.abstractmethod @ExperimentalAPI def estimate_on_dataset( self, @@ -62,7 +63,6 @@ def estimate_on_dataset( *, n_parallelism: int = os.cpu_count(), ) -> Dict[str, Any]: - """Calculates the estimate of the metrics based on the given offline dataset. Typically, the dataset is passed through only once via n_parallel tasks in diff --git a/rllib/policy/dynamic_tf_policy.py b/rllib/policy/dynamic_tf_policy.py index ac40205de94a..9645faf6e08f 100644 --- a/rllib/policy/dynamic_tf_policy.py +++ b/rllib/policy/dynamic_tf_policy.py @@ -782,7 +782,7 @@ def _initialize_loss_from_dummy_batch( {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]} ) - self._loss_input_dict.update({k: v for k, v in train_batch.items()}) + self._loss_input_dict.update(dict(train_batch)) if log_once("loss_init"): logger.debug( diff --git a/rllib/policy/dynamic_tf_policy_v2.py b/rllib/policy/dynamic_tf_policy_v2.py index e2ad3d6da0ab..fde1a3ff67e5 100644 --- a/rllib/policy/dynamic_tf_policy_v2.py +++ b/rllib/policy/dynamic_tf_policy_v2.py @@ -736,7 +736,7 @@ def _initialize_loss_from_dummy_batch( {SampleBatch.SEQ_LENS: train_batch[SampleBatch.SEQ_LENS]} ) - self._loss_input_dict.update({k: v for k, v in train_batch.items()}) + self._loss_input_dict.update(dict(train_batch)) if log_once("loss_init"): logger.debug( diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index cceca81dd5d4..87848d9fa64e 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -334,9 +334,11 @@ def compute_actions_from_input_dict( self.global_timestep += ( len(obs_batch) if isinstance(obs_batch, list) - else len(input_dict) - if isinstance(input_dict, SampleBatch) - else obs_batch.shape[0] + else ( + len(input_dict) + if isinstance(input_dict, SampleBatch) + else obs_batch.shape[0] + ) ) return fetched @@ -423,7 +425,7 @@ def compute_log_likelihoods( self._state_inputs, state_batches ) ) - builder.add_feed_dict({k: v for k, v in zip(self._state_inputs, state_batches)}) + builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) if state_batches: builder.add_feed_dict({self._seq_lens: np.ones(len(obs_batch))}) # Prev-a and r. @@ -765,9 +767,11 @@ def _initialize_loss( ) with tf1.control_dependencies(self._update_ops): self._apply_op = self.build_apply_op( - optimizer=self._optimizers - if self.config["_tf_policy_handles_more_than_one_loss"] - else self._optimizer, + optimizer=( + self._optimizers + if self.config["_tf_policy_handles_more_than_one_loss"] + else self._optimizer + ), grads_and_vars=self._grads_and_vars, ) diff --git a/rllib/utils/minibatch_utils.py b/rllib/utils/minibatch_utils.py index e27b5a7782ba..97e1e5c3b060 100644 --- a/rllib/utils/minibatch_utils.py +++ b/rllib/utils/minibatch_utils.py @@ -76,9 +76,9 @@ def __init__( self._shuffle_batch_per_epoch = shuffle_batch_per_epoch # mapping from module_id to the start index of the batch - self._start = {mid: 0 for mid in batch.policy_batches.keys()} + self._start = dict.fromkeys(batch.policy_batches.keys(), 0) # mapping from module_id to the number of epochs covered for each module_id - self._num_covered_epochs = {mid: 0 for mid in batch.policy_batches.keys()} + self._num_covered_epochs = dict.fromkeys(batch.policy_batches.keys(), 0) self._uses_new_env_runners = _uses_new_env_runners diff --git a/rllib/utils/replay_buffers/episode_replay_buffer.py b/rllib/utils/replay_buffers/episode_replay_buffer.py index 4a3309b301d9..0d93992fac56 100644 --- a/rllib/utils/replay_buffers/episode_replay_buffer.py +++ b/rllib/utils/replay_buffers/episode_replay_buffer.py @@ -117,7 +117,8 @@ def __len__(self) -> int: @override(ReplayBufferInterface) def add(self, episodes: Union[List["SingleAgentEpisode"], "SingleAgentEpisode"]): - """Converts the incoming SampleBatch into a number of SingleAgentEpisode objects. + """ + Converts the incoming SampleBatch into a number of SingleAgentEpisode objects. Then adds these episodes to the internal deque. """ diff --git a/rllib/utils/replay_buffers/tests/test_multi_agent_mixin_replay_buffer.py b/rllib/utils/replay_buffers/tests/test_multi_agent_mixin_replay_buffer.py index a463467ee286..7e954df9cba2 100644 --- a/rllib/utils/replay_buffers/tests/test_multi_agent_mixin_replay_buffer.py +++ b/rllib/utils/replay_buffers/tests/test_multi_agent_mixin_replay_buffer.py @@ -63,7 +63,7 @@ def test_mixin_sampling_episodes(self): for _ in range(20): buffer.add(batch) sample = buffer.sample(2) - assert type(sample) == MultiAgentBatch + assert isinstance(sample, MultiAgentBatch) results.append(len(sample.policy_batches[DEFAULT_POLICY_ID])) # One sample in the episode does not belong the the episode on thus # gets dropped. Full episodes are of length two. @@ -88,7 +88,7 @@ def test_mixin_sampling_sequences(self): for _ in range(400): buffer.add(batch) sample = buffer.sample(10) - assert type(sample) == MultiAgentBatch + assert isinstance(sample, MultiAgentBatch) results.append(len(sample.policy_batches[DEFAULT_POLICY_ID])) self.assertAlmostEqual(np.mean(results), 2 * len(batch), delta=0.1) @@ -113,7 +113,7 @@ def test_mixin_sampling_timesteps(self): buffer.add(batch) buffer.add(batch) sample = buffer.sample(3) - assert type(sample) == MultiAgentBatch + assert isinstance(sample, MultiAgentBatch) results.append(len(sample.policy_batches[DEFAULT_POLICY_ID])) self.assertAlmostEqual(np.mean(results), 3.0, delta=0.2) @@ -125,7 +125,7 @@ def test_mixin_sampling_timesteps(self): for _ in range(100): buffer.add(batch) sample = buffer.sample(5) - assert type(sample) == MultiAgentBatch + assert isinstance(sample, MultiAgentBatch) results.append(len(sample.policy_batches[DEFAULT_POLICY_ID])) self.assertAlmostEqual(np.mean(results), 1.5, delta=0.2) @@ -142,7 +142,7 @@ def test_mixin_sampling_timesteps(self): for _ in range(100): buffer.add(batch) sample = buffer.sample(10) - assert type(sample) == MultiAgentBatch + assert isinstance(sample, MultiAgentBatch) results.append(len(sample.policy_batches[DEFAULT_POLICY_ID])) self.assertAlmostEqual(np.mean(results), 10.0, delta=0.2) @@ -156,12 +156,12 @@ def test_mixin_sampling_timesteps(self): buffer.add(batch) # Expect exactly 1 batch to be returned. sample = buffer.sample(1) - assert type(sample) == MultiAgentBatch + assert isinstance(sample, MultiAgentBatch) self.assertTrue(len(sample) == 1) # Expect exactly 0 sample to be returned (nothing new to be returned; # no replay allowed (replay_ratio=0.0)). sample = buffer.sample(1) - assert type(sample) == MultiAgentBatch + assert isinstance(sample, MultiAgentBatch) assert len(sample.policy_batches) == 0 # If we insert and replay n times, expect roughly return batches of # len 1 (replay_ratio=0.0 -> 0% replayed samples -> 1 new and 0 old samples @@ -170,7 +170,7 @@ def test_mixin_sampling_timesteps(self): for _ in range(100): buffer.add(batch) sample = buffer.sample(1) - assert type(sample) == MultiAgentBatch + assert isinstance(sample, MultiAgentBatch) results.append(len(sample.policy_batches[DEFAULT_POLICY_ID])) self.assertAlmostEqual(np.mean(results), 1.0, delta=0.2) @@ -187,11 +187,11 @@ def test_mixin_sampling_timesteps(self): buffer.add(batch) # Expect exactly 1 sample to be returned (the new batch). sample = buffer.sample(1) - assert type(sample) == MultiAgentBatch + assert isinstance(sample, MultiAgentBatch) self.assertTrue(len(sample) == 1) # Another replay -> Expect exactly 1 sample to be returned. sample = buffer.sample(1) - assert type(sample) == MultiAgentBatch + assert isinstance(sample, MultiAgentBatch) self.assertTrue(len(sample) == 1) # If we replay n times, expect roughly return batches of # len 1 (replay_ratio=1.0 -> 100% replayed samples -> 0 new and 1 old samples @@ -199,7 +199,7 @@ def test_mixin_sampling_timesteps(self): results = [] for _ in range(100): sample = buffer.sample(1) - assert type(sample) == MultiAgentBatch + assert isinstance(sample, MultiAgentBatch) results.append(len(sample.policy_batches[DEFAULT_POLICY_ID])) self.assertAlmostEqual(np.mean(results), 1.0) diff --git a/rllib/utils/replay_buffers/tests/test_multi_agent_prioritized_replay_buffer.py b/rllib/utils/replay_buffers/tests/test_multi_agent_prioritized_replay_buffer.py index 9fd8f0043f4f..246cf7c3d055 100644 --- a/rllib/utils/replay_buffers/tests/test_multi_agent_prioritized_replay_buffer.py +++ b/rllib/utils/replay_buffers/tests/test_multi_agent_prioritized_replay_buffer.py @@ -151,7 +151,7 @@ def test_independent_mode(self): # Sample without specifying the policy should yield approx. the same # number of batches from each policy - num_sampled_dict = {_id: 0 for _id in range(num_policies)} + num_sampled_dict = dict.fromkeys(range(num_policies), 0) num_samples = 200 for i in range(num_samples): num_items = np.random.randint(1, 5) @@ -186,7 +186,7 @@ def test_update_priorities(self): # Fetch records, their indices and weights. mabatch = buffer.sample(3) - assert type(mabatch) == MultiAgentBatch + assert isinstance(mabatch, MultiAgentBatch) samplebatch = mabatch.policy_batches[DEFAULT_POLICY_ID] weights = samplebatch["weights"] @@ -211,9 +211,9 @@ def test_update_priorities(self): # (which still has a weight of 1.0). for _ in range(10): mabatch = buffer.sample(1000) - assert type(mabatch) == MultiAgentBatch + assert isinstance(mabatch, MultiAgentBatch) samplebatch = mabatch.policy_batches[DEFAULT_POLICY_ID] - assert type(mabatch) == MultiAgentBatch + assert isinstance(mabatch, MultiAgentBatch) indices = samplebatch["batch_indexes"] self.assertTrue(1900 < np.sum(indices) < 2200) # Test get_state/set_state. diff --git a/rllib/utils/replay_buffers/tests/test_multi_agent_replay_buffer.py b/rllib/utils/replay_buffers/tests/test_multi_agent_replay_buffer.py index 910ab87dcfcc..c1f302c3f293 100644 --- a/rllib/utils/replay_buffers/tests/test_multi_agent_replay_buffer.py +++ b/rllib/utils/replay_buffers/tests/test_multi_agent_replay_buffer.py @@ -121,7 +121,7 @@ def test_lockstep_mode(self): self._add_sample_batch_to_buffer(buffer, batch_size=batch_size, num_batches=2) # Sampling from it now should yield our first batch 1/3 of the time - num_sampled_dict = {_id: 0 for _id in range(self.batch_id)} + num_sampled_dict = dict.fromkeys(range(self.batch_id), 0) num_samples = 200 for i in range(num_samples): _id = get_batch_id(buffer.sample(1)) @@ -162,7 +162,7 @@ def test_independent_mode_sequences_storage_unit(self): # Sampling from it now should yield each batch that went into a # multiagent batch 1/6th of the time - num_sampled_dict = {_id: 0 for _id in range(self.batch_id)} + num_sampled_dict = dict.fromkeys(range(self.batch_id), 0) num_samples = 200 for i in range(num_samples): sample = buffer.sample(1) @@ -207,7 +207,7 @@ def test_independent_mode_multiple_policies(self): # Sample without specifying the policy should yield the same number # of batches from each policy - num_sampled_dict = {_id: 0 for _id in range(num_policies)} + num_sampled_dict = dict.fromkeys(range(num_policies), 0) num_samples = 200 for i in range(num_samples): num_items = np.random.randint(0, 5) diff --git a/rllib/utils/replay_buffers/tests/test_prioritized_episode_buffer.py b/rllib/utils/replay_buffers/tests/test_prioritized_episode_buffer.py index facc8dd5b199..c668cbf9a773 100644 --- a/rllib/utils/replay_buffers/tests/test_prioritized_episode_buffer.py +++ b/rllib/utils/replay_buffers/tests/test_prioritized_episode_buffer.py @@ -286,12 +286,7 @@ def test_update_priorities(self): sample = buffer.sample(batch_size_B=16, n_step=1) index_counts.append( - any( - [ - idx in last_sampled_indices - for idx in buffer._last_sampled_indices - ] - ) + any(idx in last_sampled_indices for idx in buffer._last_sampled_indices) ) self.assertGreater(0.15, sum(index_counts) / len(index_counts)) diff --git a/rllib/utils/replay_buffers/tests/test_prioritized_replay_buffer_replay_buffer_api.py b/rllib/utils/replay_buffers/tests/test_prioritized_replay_buffer_replay_buffer_api.py index 11d66f8bc2ed..eafc6c5f46c9 100644 --- a/rllib/utils/replay_buffers/tests/test_prioritized_replay_buffer_replay_buffer_api.py +++ b/rllib/utils/replay_buffers/tests/test_prioritized_replay_buffer_replay_buffer_api.py @@ -411,7 +411,7 @@ def test_sequences_unit(self): weight=1, ) - num_sampled_dict = {_id: 0 for _id in range(1, 5)} + num_sampled_dict = dict.fromkeys(range(1, 5), 0) num_samples = 200 for i in range(num_samples): sample = buffer.sample(1, beta=self.beta) @@ -451,7 +451,7 @@ def test_sequences_unit(self): assert buffer._next_idx == 1 assert buffer._eviction_started is True - num_sampled_dict = {_id: 0 for _id in range(1, 6)} + num_sampled_dict = dict.fromkeys(range(1, 6), 0) num_samples = 200 for i in range(num_samples): sample = buffer.sample(1, beta=self.beta) @@ -508,7 +508,7 @@ def test_episodes_unit(self): weight=1, ) - num_sampled_dict = {_id: 0 for _id in range(5)} + num_sampled_dict = dict.fromkeys(range(5), 0) num_samples = 200 for i in range(num_samples): sample = buffer.sample(1, beta=self.beta) @@ -541,7 +541,7 @@ def test_episodes_unit(self): weight=1, ) - num_sampled_dict = {_id: 0 for _id in range(7)} + num_sampled_dict = dict.fromkeys(range(7), 0) num_samples = 200 for i in range(num_samples): sample = buffer.sample(1, beta=self.beta) @@ -580,7 +580,7 @@ def test_episodes_unit(self): assert buffer._next_idx == 1 assert buffer._eviction_started is True - num_sampled_dict = {_id: 0 for _id in range(8)} + num_sampled_dict = dict.fromkeys(range(8), 0) num_samples = 200 for i in range(num_samples): sample = buffer.sample(1, beta=self.beta) diff --git a/rllib/utils/replay_buffers/tests/test_replay_buffer.py b/rllib/utils/replay_buffers/tests/test_replay_buffer.py index 87ba177f81a0..68d7c3756d27 100644 --- a/rllib/utils/replay_buffers/tests/test_replay_buffer.py +++ b/rllib/utils/replay_buffers/tests/test_replay_buffer.py @@ -147,7 +147,7 @@ def test_timesteps_unit(self): self._add_data_to_buffer(buffer, batch_size=batch_size, num_batches=2) # Sampling from it now should yield our first batch 1/3 of the time - num_sampled_dict = {_id: 0 for _id in range(self.batch_id)} + num_sampled_dict = dict.fromkeys(range(self.batch_id), 0) num_samples = 200 for i in range(num_samples): _id = buffer.sample(1)["batch_id"][0] @@ -211,7 +211,7 @@ def test_sequences_unit(self): for batch in batches: buffer.add(batch) - num_sampled_dict = {_id: 0 for _id in range(1, 5)} + num_sampled_dict = dict.fromkeys(range(1, 5), 0) num_samples = 200 for i in range(num_samples): sample = buffer.sample(1) @@ -252,7 +252,7 @@ def test_sequences_unit(self): # The first batch should now not be sampled anymore, other batches # should be sampled as before - num_sampled_dict = {_id: 0 for _id in range(2, 6)} + num_sampled_dict = dict.fromkeys(range(2, 6), 0) num_samples = 200 for i in range(num_samples): sample = buffer.sample(1) @@ -302,7 +302,7 @@ def test_episodes_unit(self): for batch in batches: buffer.add(batch) - num_sampled_dict = {_id: 0 for _id in range(5)} + num_sampled_dict = dict.fromkeys(range(5), 0) num_samples = 200 for i in range(num_samples): sample = buffer.sample(1) @@ -334,7 +334,7 @@ def test_episodes_unit(self): ) ) - num_sampled_dict = {_id: 0 for _id in range(7)} + num_sampled_dict = dict.fromkeys(range(7), 0) num_samples = 200 for i in range(num_samples): sample = buffer.sample(1) @@ -372,7 +372,7 @@ def test_episodes_unit(self): assert buffer._next_idx == 1 assert buffer._eviction_started is True - num_sampled_dict = {_id: 0 for _id in range(8)} + num_sampled_dict = dict.fromkeys(range(8), 0) num_samples = 200 for i in range(num_samples): sample = buffer.sample(1) diff --git a/rllib/utils/tests/test_actor_manager.py b/rllib/utils/tests/test_actor_manager.py index 0d5720b886a9..c53026d1e611 100644 --- a/rllib/utils/tests/test_actor_manager.py +++ b/rllib/utils/tests/test_actor_manager.py @@ -237,7 +237,7 @@ def test_sync_call_not_ignore_error(self): wait_for_restore() # Some calls did error out. - self.assertTrue(any([not r.ok for r in results])) + self.assertTrue(any(not r.ok for r in results)) manager.clear() @@ -248,7 +248,7 @@ def test_sync_call_not_bringing_back_actors(self): results = manager.foreach_actor(lambda w: w.call()) # Some calls did error out. - self.assertTrue(any([not r.ok for r in results])) + self.assertTrue(any(not r.ok for r in results)) # Wait for actors to recover. wait_for_restore() diff --git a/rllib/utils/tests/test_utils.py b/rllib/utils/tests/test_utils.py index f79f5b91b7de..4bd6d833ba68 100644 --- a/rllib/utils/tests/test_utils.py +++ b/rllib/utils/tests/test_utils.py @@ -48,18 +48,14 @@ class TestUtils(unittest.TestCase): "c": {"ca": np.array([[[1, 2]], [[3, 5]]]), "cb": np.array([[1.0], [2.0]])}, } # Corresponding space struct. - spaces = dict( - { - "a": gym.spaces.Discrete(4), - "b": (gym.spaces.Box(-1.0, 10.0, (3,)), gym.spaces.Box(-1.0, 1.0, (3, 1))), - "c": dict( - { - "ca": gym.spaces.MultiDiscrete([4, 6]), - "cb": gym.spaces.Box(-1.0, 1.0, ()), - } - ), - } - ) + spaces = { + "a": gym.spaces.Discrete(4), + "b": (gym.spaces.Box(-1.0, 10.0, (3,)), gym.spaces.Box(-1.0, 1.0, (3, 1))), + "c": { + "ca": gym.spaces.MultiDiscrete([4, 6]), + "cb": gym.spaces.Box(-1.0, 1.0, ()), + }, + } @classmethod def setUpClass(cls) -> None: diff --git a/rllib_contrib/alpha_star/src/rllib_alpha_star/alpha_star/alpha_star.py b/rllib_contrib/alpha_star/src/rllib_alpha_star/alpha_star/alpha_star.py index 3850954e78e6..2a446d6e7188 100644 --- a/rllib_contrib/alpha_star/src/rllib_alpha_star/alpha_star/alpha_star.py +++ b/rllib_contrib/alpha_star/src/rllib_alpha_star/alpha_star/alpha_star.py @@ -2,6 +2,7 @@ A multi-agent, distributed multi-GPU, league-capable asynch. PPO ================================================================ """ + from typing import Any, Dict, Optional, Type, Union import gymnasium as gym @@ -446,7 +447,7 @@ def training_step(self) -> ResultDict: sample_results = self._sampling_actor_manager.get_ready() # Update sample counters. for sample_result in sample_results.values(): - for (env_steps, agent_steps) in sample_result: + for env_steps, agent_steps in sample_result: self._counters[NUM_ENV_STEPS_SAMPLED] += env_steps self._counters[NUM_AGENT_STEPS_SAMPLED] += agent_steps diff --git a/rllib_contrib/alpha_star/src/rllib_alpha_star/alpha_star/league_builder.py b/rllib_contrib/alpha_star/src/rllib_alpha_star/alpha_star/league_builder.py index 5c7125c856bf..0cadb6639817 100644 --- a/rllib_contrib/alpha_star/src/rllib_alpha_star/alpha_star/league_builder.py +++ b/rllib_contrib/alpha_star/src/rllib_alpha_star/alpha_star/league_builder.py @@ -1,6 +1,6 @@ import logging import re -from abc import ABCMeta +from abc import ABCMeta, abstractmethod from collections import defaultdict from typing import Any, DefaultDict, Dict @@ -33,6 +33,7 @@ def __init__(self, algo: Algorithm, algo_config: AlgorithmConfig): self.algo = algo self.config = algo_config + @abstractmethod def build_league(self, result: ResultDict) -> None: """Method containing league-building logic. Called after train step. @@ -170,10 +171,8 @@ def __init__( policies_to_train.append(pid) # Build initial policy mapping function: main_0 vs main_exploiter_0. - self.config.policy_mapping_fn = ( - lambda agent_id, episode, worker, **kw: "main_0" - if episode.episode_id % 2 == agent_id - else "main_exploiter_0" + self.config.policy_mapping_fn = lambda agent_id, episode, worker, **kw: ( + "main_0" if episode.episode_id % 2 == agent_id else "main_exploiter_0" ) self.config.policies = policies self.config.policies_to_train = policies_to_train diff --git a/rllib_contrib/apex_dqn/src/rllib_apex_dqn/apex_dqn/apex_dqn.py b/rllib_contrib/apex_dqn/src/rllib_apex_dqn/apex_dqn/apex_dqn.py index 868292f81be0..aa2ae2502a28 100644 --- a/rllib_contrib/apex_dqn/src/rllib_apex_dqn/apex_dqn/apex_dqn.py +++ b/rllib_contrib/apex_dqn/src/rllib_apex_dqn/apex_dqn/apex_dqn.py @@ -10,7 +10,8 @@ Detailed documentation: https://docs.ray.io/en/master/rllib-algorithms.html#distributed-prioritized-experience-replay-ape-x -""" # noqa: E501 +""" + import copy import platform import random @@ -396,9 +397,11 @@ def training_step(self) -> ResultDict: # Update target network every `target_network_update_freq` sample steps. cur_ts = self._counters[ - NUM_AGENT_STEPS_SAMPLED - if self.config.count_steps_by == "agent_steps" - else NUM_ENV_STEPS_SAMPLED + ( + NUM_AGENT_STEPS_SAMPLED + if self.config.count_steps_by == "agent_steps" + else NUM_ENV_STEPS_SAMPLED + ) ] if cur_ts > self.config.num_steps_sampled_before_learning_starts: @@ -500,9 +503,11 @@ def update_workers(self, _num_samples_ready: Dict[ActorHandle, int]) -> int: with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: curr_weights = self.curr_learner_weights timestep = self._counters[ - NUM_AGENT_STEPS_TRAINED - if self.config.count_steps_by == "agent_steps" - else NUM_ENV_STEPS_TRAINED + ( + NUM_AGENT_STEPS_TRAINED + if self.config.count_steps_by == "agent_steps" + else NUM_ENV_STEPS_TRAINED + ) ] for ( remote_sampler_worker_id, @@ -646,9 +651,11 @@ def update_target_networks(self, num_new_trained_samples) -> None: ) self._counters[NUM_TARGET_UPDATES] += 1 self._counters[LAST_TARGET_UPDATE_TS] = self._counters[ - NUM_AGENT_STEPS_TRAINED - if self.config.count_steps_by == "agent_steps" - else NUM_ENV_STEPS_TRAINED + ( + NUM_AGENT_STEPS_TRAINED + if self.config.count_steps_by == "agent_steps" + else NUM_ENV_STEPS_TRAINED + ) ] def _get_shard0_replay_stats(self) -> Dict[str, Any]: diff --git a/rllib_contrib/bandit/src/rllib_bandit/env/bandit_envs_discrete.py b/rllib_contrib/bandit/src/rllib_bandit/env/bandit_envs_discrete.py index 1ad32366c6a0..db57eb6b1c9f 100644 --- a/rllib_contrib/bandit/src/rllib_bandit/env/bandit_envs_discrete.py +++ b/rllib_contrib/bandit/src/rllib_bandit/env/bandit_envs_discrete.py @@ -56,7 +56,7 @@ class LinearDiscreteEnv(gym.Env): def __init__(self, config=None): self.config = copy.copy(self.DEFAULT_CONFIG_LINEAR) - if config is not None and type(config) == dict: + if config is not None and isinstance(config, dict): self.config.update(config) self.feature_dim = self.config["feature_dim"] @@ -128,7 +128,7 @@ class WheelBanditEnv(gym.Env): def __init__(self, config=None): self.config = copy.copy(self.DEFAULT_CONFIG_WHEEL) - if config is not None and type(config) == dict: + if config is not None and isinstance(config, dict): self.config.update(config) self.delta = self.config["delta"] diff --git a/rllib_contrib/dt/tests/test_segmentation_buffer.py b/rllib_contrib/dt/tests/test_segmentation_buffer.py index c9036e28ff9a..c16a03bb45cb 100644 --- a/rllib_contrib/dt/tests/test_segmentation_buffer.py +++ b/rllib_contrib/dt/tests/test_segmentation_buffer.py @@ -89,9 +89,9 @@ def _get_internal_buffer( """Get the internal buffer list from the buffer. If MultiAgent then return the internal buffer corresponding to the given policy_id. """ - if type(buffer) == SegmentationBuffer: + if isinstance(buffer, SegmentationBuffer): return buffer._buffer - elif type(buffer) == MultiAgentSegmentationBuffer: + elif isinstance(buffer, MultiAgentSegmentationBuffer): return buffer.buffers[policy_id]._buffer else: raise NotImplementedError @@ -104,9 +104,9 @@ def _as_sample_batch( """Returns a SampleBatch. If MultiAgentBatch then return the SampleBatch corresponding to the given policy_id. """ - if type(batch) == SampleBatch: + if isinstance(batch, SampleBatch): return batch - elif type(batch) == MultiAgentBatch: + elif isinstance(batch, MultiAgentBatch): return batch.policy_batches[policy_id] else: raise NotImplementedError diff --git a/rllib_contrib/leela_chess_zero/src/rllib_leela_chess_zero/leela_chess_zero/leela_chess_zero_model.py b/rllib_contrib/leela_chess_zero/src/rllib_leela_chess_zero/leela_chess_zero/leela_chess_zero_model.py index bfe8c7e29392..b05b43606d03 100644 --- a/rllib_contrib/leela_chess_zero/src/rllib_leela_chess_zero/leela_chess_zero/leela_chess_zero_model.py +++ b/rllib_contrib/leela_chess_zero/src/rllib_leela_chess_zero/leela_chess_zero/leela_chess_zero_model.py @@ -93,7 +93,7 @@ def forward(self, input_dict, state, seq_lens): print(input_dict) raise Exception("No observation in input_dict") if self.alpha_zero_obs: - if not type(obs) == torch.Tensor: + if not isinstance(obs, torch.Tensor): obs = torch.from_numpy(obs.astype(np.float32)) action_mask = torch.from_numpy(action_mask.astype(np.float32)) try: diff --git a/rllib_contrib/leela_chess_zero/src/rllib_leela_chess_zero/leela_chess_zero/mcts.py b/rllib_contrib/leela_chess_zero/src/rllib_leela_chess_zero/leela_chess_zero/mcts.py index b0151f02a82d..46cfa58a5f97 100644 --- a/rllib_contrib/leela_chess_zero/src/rllib_leela_chess_zero/leela_chess_zero/mcts.py +++ b/rllib_contrib/leela_chess_zero/src/rllib_leela_chess_zero/leela_chess_zero/mcts.py @@ -41,11 +41,11 @@ def __init__( multi_agent = True if multi_agent: current_agent = self.state.agent_selection - if type(self.reward) == dict: + if isinstance(self.reward, dict): self.reward = self.reward[current_agent] - if type(self.done) == dict: + if isinstance(self.done, dict): self.done = self.done[current_agent] - if type(self.obs) == dict: + if isinstance(self.obs, dict): self.valid_actions = obs[current_agent]["action_mask"].astype(bool) self.obs = obs[current_agent] else: