diff --git a/pyproject.toml b/pyproject.toml index f335bb349442..c8bfb82794d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,22 +65,6 @@ afterray = ["psutil", "setproctitle"] "python/ray/__init__.py" = ["I"] "python/ray/dag/__init__.py" = ["I"] "python/ray/air/__init__.py" = ["I"] -# "rllib/__init__.py" = ["I"] -# "rllib/benchmarks/*" = ["I"] -# "rllib/connectors/*" = ["I"] -# "rllib/evaluation/*" = ["I"] -# "rllib/models/*" = ["I"] -"rllib/utils/*" = ["I"] -"rllib/algorithms/*" = ["I"] -# "rllib/core/*" = ["I"] -# "rllib/examples/*" = ["I"] -# "rllib/offline/*" = ["I"] -# "rllib/tests/*" = ["I"] -# "rllib/callbacks/*" = ["I"] -# "rllib/env/*" = ["I"] -# "rllib/execution/*" = ["I"] -# "rllib/policy/*" = ["I"] -# "rllib/tuned_examples/*" = ["I"] "release/*" = ["I"] # TODO(matthewdeng): Remove this line diff --git a/rllib/algorithms/__init__.py b/rllib/algorithms/__init__.py index fdc21775e119..f7e0696a0d32 100644 --- a/rllib/algorithms/__init__.py +++ b/rllib/algorithms/__init__.py @@ -6,15 +6,14 @@ from ray.rllib.algorithms.dqn.dqn import DQN, DQNConfig from ray.rllib.algorithms.impala.impala import ( IMPALA, - IMPALAConfig, Impala, + IMPALAConfig, ImpalaConfig, ) from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig from ray.rllib.algorithms.sac.sac import SAC, SACConfig - __all__ = [ "Algorithm", "AlgorithmConfig", diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index 701d21bd87c7..3583df0d31e4 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -13,6 +13,7 @@ from collections import defaultdict from datetime import datetime from typing import ( + TYPE_CHECKING, Any, Callable, Collection, @@ -23,7 +24,6 @@ Set, Tuple, Type, - TYPE_CHECKING, Union, ) @@ -47,9 +47,9 @@ from ray.rllib.algorithms.utils import ( AggregatorActor, _get_env_runner_bundles, - _get_offline_eval_runner_bundles, _get_learner_bundles, _get_main_process_bundle, + _get_offline_eval_runner_bundles, ) from ray.rllib.callbacks.utils import make_callback from ray.rllib.connectors.agent.obs_preproc import ObsPreprocessorConnector @@ -84,30 +84,30 @@ from ray.rllib.execution.rollout_ops import synchronous_parallel_sample from ray.rllib.offline import get_dataset_and_shards from ray.rllib.offline.estimators import ( - OffPolicyEstimator, - ImportanceSampling, - WeightedImportanceSampling, DirectMethod, DoublyRobust, + ImportanceSampling, + OffPolicyEstimator, + WeightedImportanceSampling, ) from ray.rllib.offline.offline_evaluator import OfflineEvaluator from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch -from ray.rllib.utils import deep_update, FilterManager, force_list +from ray.rllib.utils import FilterManager, deep_update, force_list from ray.rllib.utils.actor_manager import FaultTolerantActorManager from ray.rllib.utils.annotations import ( DeveloperAPI, ExperimentalAPI, OldAPIStack, - override, OverrideToImplementCustomLogic, OverrideToImplementCustomLogic_CallToSuperRecommended, PublicAPI, + override, ) from ray.rllib.utils.checkpoints import ( - Checkpointable, CHECKPOINT_VERSION, CHECKPOINT_VERSION_LEARNER_AND_ENV_RUNNER, + Checkpointable, get_checkpoint_info, try_import_msgpack, ) @@ -134,9 +134,9 @@ NUM_AGENT_STEPS_TRAINED, NUM_AGENT_STEPS_TRAINED_LIFETIME, NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER, NUM_ENV_STEPS_SAMPLED_LIFETIME, NUM_ENV_STEPS_SAMPLED_THIS_ITER, - 
NUM_ENV_STEPS_SAMPLED_FOR_EVALUATION_THIS_ITER, NUM_ENV_STEPS_TRAINED, NUM_ENV_STEPS_TRAINED_LIFETIME, NUM_EPISODES, @@ -147,13 +147,13 @@ RESTORE_ENV_RUNNERS_TIMER, RESTORE_EVAL_ENV_RUNNERS_TIMER, RESTORE_OFFLINE_EVAL_RUNNERS_TIMER, + STEPS_TRAINED_THIS_ITER_COUNTER, SYNCH_ENV_CONNECTOR_STATES_TIMER, SYNCH_EVAL_ENV_CONNECTOR_STATES_TIMER, SYNCH_WORKER_WEIGHTS_TIMER, TIMERS, TRAINING_ITERATION_TIMER, TRAINING_STEP_TIMER, - STEPS_TRAINED_THIS_ITER_COUNTER, ) from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.metrics.metrics_logger import MetricsLogger @@ -164,7 +164,7 @@ ) from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer, ReplayBuffer from ray.rllib.utils.runners.runner_group import RunnerGroup -from ray.rllib.utils.serialization import deserialize_type, NOT_SERIALIZABLE +from ray.rllib.utils.serialization import NOT_SERIALIZABLE, deserialize_type from ray.rllib.utils.spaces import space_utils from ray.rllib.utils.typing import ( AgentConnectorDataType, @@ -191,8 +191,7 @@ from ray.tune.execution.placement_groups import PlacementGroupFactory from ray.tune.experiment.trial import ExportFormat from ray.tune.logger import Logger, UnifiedLogger -from ray.tune.registry import ENV_CREATOR, _global_registry -from ray.tune.registry import get_trainable_cls +from ray.tune.registry import ENV_CREATOR, _global_registry, get_trainable_cls from ray.tune.resources import Resources from ray.tune.result import TRAINING_ITERATION from ray.tune.trainable import Trainable diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 9d6172b4c94e..3953c2c4dfe7 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -1,10 +1,11 @@ import copy import dataclasses -from enum import Enum import logging import math import sys +from enum import Enum from typing import ( + TYPE_CHECKING, Any, Callable, Collection, @@ -13,16 +14,20 @@ Optional, Tuple, Type, - TYPE_CHECKING, Union, ) -from typing_extensions import Self import gymnasium as gym import tree from packaging import version +from typing_extensions import Self import ray +from ray._common.deprecation import ( + DEPRECATED_VALUE, + Deprecated, + deprecation_warning, +) from ray.rllib.callbacks.callbacks import RLlibCallback from ray.rllib.connectors.connector_v2 import ConnectorV2 from ray.rllib.core import DEFAULT_MODULE_ID @@ -34,7 +39,7 @@ from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.rl_module.rl_module import RLModuleSpec -from ray.rllib.env import INPUT_ENV_SPACES, INPUT_ENV_SINGLE_SPACES +from ray.rllib.env import INPUT_ENV_SINGLE_SPACES, INPUT_ENV_SPACES from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.wrappers.atari_wrappers import is_atari from ray.rllib.evaluation.collectors.sample_collector import SampleCollector @@ -49,11 +54,6 @@ OldAPIStack, OverrideToImplementCustomLogic_CallToSuperRecommended, ) -from ray._common.deprecation import ( - DEPRECATED_VALUE, - Deprecated, - deprecation_warning, -) from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.from_config import NotProvided, from_config from ray.rllib.utils.schedules.scheduler import Scheduler @@ -84,7 +84,6 @@ from ray.util import log_once from ray.util.placement_group import PlacementGroup - if TYPE_CHECKING: from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.core.learner 
import Learner diff --git a/rllib/algorithms/appo/appo.py b/rllib/algorithms/appo/appo.py index d536c2f9498d..d023913ac002 100644 --- a/rllib/algorithms/appo/appo.py +++ b/rllib/algorithms/appo/appo.py @@ -10,23 +10,24 @@ https://arxiv.org/pdf/1912.00167 """ +import logging from typing import Optional, Type + from typing_extensions import Self -import logging +from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.algorithms.impala.impala import IMPALA, IMPALAConfig from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override -from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning from ray.rllib.utils.metrics import ( LAST_TARGET_UPDATE_TS, + LEARNER_STATS_KEY, NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED, NUM_TARGET_UPDATES, ) -from ray.rllib.utils.metrics import LEARNER_STATS_KEY logger = logging.getLogger(__name__) diff --git a/rllib/algorithms/appo/appo_tf_policy.py b/rllib/algorithms/appo/appo_tf_policy.py index 4af36f099df9..eab4bfefeb2e 100644 --- a/rllib/algorithms/appo/appo_tf_policy.py +++ b/rllib/algorithms/appo/appo_tf_policy.py @@ -5,37 +5,37 @@ Keep in sync with changes to VTraceTFPolicy. """ -import numpy as np import logging -import gymnasium as gym from typing import Dict, List, Optional, Type, Union +import gymnasium as gym +import numpy as np + from ray.rllib.algorithms.appo.utils import make_appo_models from ray.rllib.algorithms.impala import vtrace_tf as vtrace from ray.rllib.algorithms.impala.impala_tf_policy import ( - _make_time_major, VTraceClipGradients, VTraceOptimizer, + _make_time_major, ) from ray.rllib.evaluation.postprocessing import ( + Postprocessing, compute_bootstrap_value, compute_gae_for_sample_batch, - Postprocessing, ) -from ray.rllib.models.tf.tf_action_dist import Categorical -from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import Categorical, TFActionDistribution from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 +from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.tf_mixins import ( EntropyCoeffSchedule, - LearningRateSchedule, - KLCoeffMixin, - ValueNetworkMixin, GradStatsMixin, + KLCoeffMixin, + LearningRateSchedule, TargetNetworkMixin, + ValueNetworkMixin, ) -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.utils.annotations import ( override, ) diff --git a/rllib/algorithms/appo/appo_torch_policy.py b/rllib/algorithms/appo/appo_torch_policy.py index 1d28138c8c25..f150c6761cac 100644 --- a/rllib/algorithms/appo/appo_torch_policy.py +++ b/rllib/algorithms/appo/appo_torch_policy.py @@ -5,37 +5,38 @@ Keep in sync with changes to VTraceTFPolicy. 
""" -import gymnasium as gym -import numpy as np import logging from typing import Any, Dict, List, Optional, Type, Union +import gymnasium as gym +import numpy as np + import ray -from ray.rllib.algorithms.appo.utils import make_appo_models import ray.rllib.algorithms.impala.vtrace_torch as vtrace +from ray.rllib.algorithms.appo.utils import make_appo_models from ray.rllib.algorithms.impala.impala_torch_policy import ( - make_time_major, VTraceOptimizer, + make_time_major, ) from ray.rllib.evaluation.postprocessing import ( + Postprocessing, compute_bootstrap_value, compute_gae_for_sample_batch, - Postprocessing, ) from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import ( - TorchDistributionWrapper, TorchCategorical, + TorchDistributionWrapper, ) from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_mixins import ( EntropyCoeffSchedule, - LearningRateSchedule, KLCoeffMixin, - ValueNetworkMixin, + LearningRateSchedule, TargetNetworkMixin, + ValueNetworkMixin, ) from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 from ray.rllib.utils.annotations import override diff --git a/rllib/algorithms/appo/default_appo_rl_module.py b/rllib/algorithms/appo/default_appo_rl_module.py index 9152ac43d9d0..e6eb13d23bf1 100644 --- a/rllib/algorithms/appo/default_appo_rl_module.py +++ b/rllib/algorithms/appo/default_appo_rl_module.py @@ -8,12 +8,11 @@ TARGET_NETWORK_ACTION_DIST_INPUTS, TargetNetworkAPI, ) -from ray.rllib.utils.typing import NetworkType - from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) +from ray.rllib.utils.typing import NetworkType from ray.util.annotations import DeveloperAPI diff --git a/rllib/algorithms/appo/tests/test_appo.py b/rllib/algorithms/appo/tests/test_appo.py index 6986eb1d2146..d6271f575104 100644 --- a/rllib/algorithms/appo/tests/test_appo.py +++ b/rllib/algorithms/appo/tests/test_appo.py @@ -11,9 +11,9 @@ NUM_ENV_STEPS_SAMPLED_LIFETIME, ) from ray.rllib.utils.test_utils import ( + check_compute_single_action, check_train_results, check_train_results_new_api_stack, - check_compute_single_action, ) diff --git a/rllib/algorithms/appo/tests/test_appo_learner.py b/rllib/algorithms/appo/tests/test_appo_learner.py index bd8cbffc10eb..92f1df9f8608 100644 --- a/rllib/algorithms/appo/tests/test_appo_learner.py +++ b/rllib/algorithms/appo/tests/test_appo_learner.py @@ -1,6 +1,6 @@ import unittest -import numpy as np +import numpy as np import tree # pip install dm_tree import ray @@ -13,7 +13,6 @@ from ray.rllib.utils.metrics import LEARNER_RESULTS from ray.rllib.utils.torch_utils import convert_to_torch_tensor - frag_length = 50 FAKE_BATCH = { @@ -119,7 +118,8 @@ def test_kl_coeff_changes(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/appo/torch/appo_torch_learner.py b/rllib/algorithms/appo/torch/appo_torch_learner.py index 62a4198952ec..9e3bbfca3b92 100644 --- a/rllib/algorithms/appo/torch/appo_torch_learner.py +++ b/rllib/algorithms/appo/torch/appo_torch_learner.py @@ -12,9 +12,9 @@ from typing import Dict from ray.rllib.algorithms.appo.appo import ( - APPOConfig, LEARNER_RESULTS_CURR_KL_COEFF_KEY, LEARNER_RESULTS_KL_KEY, + APPOConfig, ) from ray.rllib.algorithms.appo.appo_learner import APPOLearner from 
ray.rllib.algorithms.impala.torch.impala_torch_learner import IMPALATorchLearner @@ -23,7 +23,7 @@ vtrace_torch, ) from ray.rllib.core.columns import Columns -from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY +from ray.rllib.core.learner.learner import ENTROPY_KEY, POLICY_LOSS_KEY, VF_LOSS_KEY from ray.rllib.core.rl_module.apis import ( TARGET_NETWORK_ACTION_DIST_INPUTS, TargetNetworkAPI, diff --git a/rllib/algorithms/bc/__init__.py b/rllib/algorithms/bc/__init__.py index 0bf454356c60..ac3749f7a57f 100644 --- a/rllib/algorithms/bc/__init__.py +++ b/rllib/algorithms/bc/__init__.py @@ -1,4 +1,4 @@ -from ray.rllib.algorithms.bc.bc import BCConfig, BC +from ray.rllib.algorithms.bc.bc import BC, BCConfig __all__ = [ "BC", diff --git a/rllib/algorithms/bc/bc_catalog.py b/rllib/algorithms/bc/bc_catalog.py index 1ac0e935266b..54a01ddd649c 100644 --- a/rllib/algorithms/bc/bc_catalog.py +++ b/rllib/algorithms/bc/bc_catalog.py @@ -2,9 +2,9 @@ import gymnasium as gym from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian +from ray.rllib.core.models.base import Model from ray.rllib.core.models.catalog import Catalog from ray.rllib.core.models.configs import FreeLogStdMLPHeadConfig, MLPHeadConfig -from ray.rllib.core.models.base import Model from ray.rllib.utils.annotations import OverrideToImplementCustomLogic diff --git a/rllib/algorithms/bc/tests/test_bc.py b/rllib/algorithms/bc/tests/test_bc.py index d3bbf371dad2..edec3c3422ed 100644 --- a/rllib/algorithms/bc/tests/test_bc.py +++ b/rllib/algorithms/bc/tests/test_bc.py @@ -1,7 +1,7 @@ -from pathlib import Path import unittest -import ray +from pathlib import Path +import ray from ray.rllib.algorithms.bc import BCConfig from ray.rllib.utils.metrics import ( ENV_RUNNER_RESULTS, @@ -88,7 +88,8 @@ def test_bc_compilation_and_learning_from_offline_file(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/callbacks.py b/rllib/algorithms/callbacks.py index 49e59d0c6a3e..9330e66335d7 100644 --- a/rllib/algorithms/callbacks.py +++ b/rllib/algorithms/callbacks.py @@ -2,7 +2,6 @@ from ray.rllib.callbacks.callbacks import RLlibCallback from ray.rllib.callbacks.utils import _make_multi_callbacks - # Backward compatibility DefaultCallbacks = RLlibCallback make_multi_callbacks = _make_multi_callbacks diff --git a/rllib/algorithms/cql/cql.py b/rllib/algorithms/cql/cql.py index 60110cb6d71a..681f5210c6dc 100644 --- a/rllib/algorithms/cql/cql.py +++ b/rllib/algorithms/cql/cql.py @@ -1,7 +1,12 @@ import logging from typing import Optional, Type, Union + from typing_extensions import Self +from ray._common.deprecation import ( + DEPRECATED_VALUE, + deprecation_warning, +) from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy @@ -26,24 +31,20 @@ ) from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import OldAPIStack, override -from ray._common.deprecation import ( - DEPRECATED_VALUE, - deprecation_warning, -) from ray.rllib.utils.framework import try_import_tf, try_import_tfp from ray.rllib.utils.metrics import ( + LAST_TARGET_UPDATE_TS, LEARNER_RESULTS, LEARNER_UPDATE_TIMER, - LAST_TARGET_UPDATE_TS, NUM_AGENT_STEPS_SAMPLED, NUM_AGENT_STEPS_TRAINED, NUM_ENV_STEPS_SAMPLED, NUM_ENV_STEPS_TRAINED, NUM_TARGET_UPDATES, OFFLINE_SAMPLING_TIMER, - 
TARGET_NET_UPDATE_TIMER, - SYNCH_WORKER_WEIGHTS_TIMER, SAMPLE_TIMER, + SYNCH_WORKER_WEIGHTS_TIMER, + TARGET_NET_UPDATE_TIMER, TIMERS, ) from ray.rllib.utils.typing import ResultDict, RLModuleSpecType diff --git a/rllib/algorithms/cql/cql_tf_policy.py b/rllib/algorithms/cql/cql_tf_policy.py index 0bfc871f328d..ae6c4f8d4fef 100644 --- a/rllib/algorithms/cql/cql_tf_policy.py +++ b/rllib/algorithms/cql/cql_tf_policy.py @@ -1,40 +1,41 @@ """ TensorFlow policy class used for CQL. """ +import logging from functools import partial -import numpy as np +from typing import Dict, List, Type, Union + import gymnasium as gym -import logging +import numpy as np import tree -from typing import Dict, List, Type, Union import ray from ray.rllib.algorithms.sac.sac_tf_policy import ( + ActorCriticOptimizerMixin as SACActorCriticOptimizerMixin, + ComputeTDErrorMixin, + _get_dist_class, apply_gradients as sac_apply_gradients, + build_sac_model, compute_and_clip_gradients as sac_compute_and_clip_gradients, get_distribution_inputs_and_class, - _get_dist_class, - build_sac_model, postprocess_trajectory, setup_late_mixins, stats, validate_spaces, - ActorCriticOptimizerMixin as SACActorCriticOptimizerMixin, - ComputeTDErrorMixin, ) from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution -from ray.rllib.policy.tf_mixins import TargetNetworkMixin -from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_mixins import TargetNetworkMixin +from ray.rllib.policy.tf_policy_template import build_tf_policy from ray.rllib.utils.exploration.random import Random from ray.rllib.utils.framework import get_variable, try_import_tf, try_import_tfp from ray.rllib.utils.typing import ( + AlgorithmConfigDict, LocalOptimizer, ModelGradients, TensorType, - AlgorithmConfigDict, ) tf1, tf, tfv = try_import_tf() diff --git a/rllib/algorithms/cql/cql_torch_policy.py b/rllib/algorithms/cql/cql_torch_policy.py index 2f67c8d642bb..a7fab43bda61 100644 --- a/rllib/algorithms/cql/cql_torch_policy.py +++ b/rllib/algorithms/cql/cql_torch_policy.py @@ -1,40 +1,41 @@ """ PyTorch policy class used for CQL. 
""" -import numpy as np -import gymnasium as gym import logging -import tree from typing import Dict, List, Tuple, Type, Union +import gymnasium as gym +import numpy as np +import tree + import ray from ray.rllib.algorithms.sac.sac_tf_policy import ( postprocess_trajectory, validate_spaces, ) from ray.rllib.algorithms.sac.sac_torch_policy import ( + ComputeTDErrorMixin, _get_dist_class, - stats, + action_distribution_fn, build_sac_model_and_action_dist, optimizer_fn, - ComputeTDErrorMixin, setup_late_mixins, - action_distribution_fn, + stats, ) -from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.policy.policy_template import build_policy_class +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper from ray.rllib.policy.policy import Policy -from ray.rllib.policy.torch_mixins import TargetNetworkMixin +from ray.rllib.policy.policy_template import build_policy_class from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_mixins import TargetNetworkMixin from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY -from ray.rllib.utils.typing import LocalOptimizer, TensorType, AlgorithmConfigDict from ray.rllib.utils.torch_utils import ( apply_grad_clipping, - convert_to_torch_tensor, concat_multi_gpu_td_errors, + convert_to_torch_tensor, ) +from ray.rllib.utils.typing import AlgorithmConfigDict, LocalOptimizer, TensorType torch, nn = try_import_torch() F = nn.functional diff --git a/rllib/algorithms/cql/tests/test_cql_old_api_stack.py b/rllib/algorithms/cql/tests/test_cql_old_api_stack.py index 1321741253a8..c2d3686da71c 100644 --- a/rllib/algorithms/cql/tests/test_cql_old_api_stack.py +++ b/rllib/algorithms/cql/tests/test_cql_old_api_stack.py @@ -1,6 +1,6 @@ -from pathlib import Path import os import unittest +from pathlib import Path import ray from ray.rllib.algorithms import cql @@ -121,7 +121,8 @@ def test_cql_compilation(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/cql/torch/cql_torch_learner.py b/rllib/algorithms/cql/torch/cql_torch_learner.py index e9f6897d3c83..4c04fb5de873 100644 --- a/rllib/algorithms/cql/torch/cql_torch_learner.py +++ b/rllib/algorithms/cql/torch/cql_torch_learner.py @@ -1,27 +1,27 @@ from typing import Dict -from ray.tune.result import TRAINING_ITERATION +from ray.rllib.algorithms.cql.cql import CQLConfig from ray.rllib.algorithms.sac.sac_learner import ( LOGPS_KEY, QF_LOSS_KEY, - QF_MEAN_KEY, QF_MAX_KEY, + QF_MEAN_KEY, QF_MIN_KEY, QF_PREDS, QF_TWIN_LOSS_KEY, QF_TWIN_PREDS, TD_ERROR_MEAN_KEY, ) -from ray.rllib.algorithms.cql.cql import CQLConfig from ray.rllib.algorithms.sac.torch.sac_torch_learner import SACTorchLearner from ray.rllib.core.columns import Columns from ray.rllib.core.learner.learner import ( POLICY_LOSS_KEY, ) from ray.rllib.utils.annotations import override -from ray.rllib.utils.metrics import ALL_MODULES from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.metrics import ALL_MODULES from ray.rllib.utils.typing import ModuleID, ParamDict, TensorType +from ray.tune.result import TRAINING_ITERATION torch, nn = try_import_torch() diff --git a/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py b/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py index 32e90815710e..1c2e7a7a2301 100644 --- 
a/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py +++ b/rllib/algorithms/cql/torch/default_cql_torch_rl_module.py @@ -1,11 +1,12 @@ -import tree from typing import Any, Dict, Optional +import tree + +from ray.rllib.algorithms.sac.sac_catalog import SACCatalog from ray.rllib.algorithms.sac.sac_learner import ( QF_PREDS, QF_TWIN_PREDS, ) -from ray.rllib.algorithms.sac.sac_catalog import SACCatalog from ray.rllib.algorithms.sac.torch.default_sac_torch_rl_module import ( DefaultSACTorchRLModule, ) diff --git a/rllib/algorithms/dqn/default_dqn_rl_module.py b/rllib/algorithms/dqn/default_dqn_rl_module.py index 78f9fe2c2e60..b4062ead7adf 100644 --- a/rllib/algorithms/dqn/default_dqn_rl_module.py +++ b/rllib/algorithms/dqn/default_dqn_rl_module.py @@ -3,17 +3,16 @@ from ray.rllib.core.learner.utils import make_target_network from ray.rllib.core.models.base import Encoder, Model -from ray.rllib.core.rl_module.apis import QNetAPI, InferenceOnlyAPI, TargetNetworkAPI +from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, QNetAPI, TargetNetworkAPI from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic, + override, ) from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.typing import NetworkType, TensorType from ray.util.annotations import DeveloperAPI - QF_PREDS = "qf_preds" ATOMS = "atoms" QF_LOGITS = "qf_logits" diff --git a/rllib/algorithms/dqn/distributional_q_tf_model.py b/rllib/algorithms/dqn/distributional_q_tf_model.py index a4dd63f587b7..421f5716d2b7 100644 --- a/rllib/algorithms/dqn/distributional_q_tf_model.py +++ b/rllib/algorithms/dqn/distributional_q_tf_model.py @@ -3,6 +3,7 @@ from typing import List import gymnasium as gym + from ray.rllib.models.tf.layers import NoisyLayer from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.utils.annotations import OldAPIStack diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py index 02c322b48eee..a5ca9a754d68 100644 --- a/rllib/algorithms/dqn/dqn.py +++ b/rllib/algorithms/dqn/dqn.py @@ -9,12 +9,14 @@ https://docs.ray.io/en/master/rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn """ # noqa: E501 -from collections import defaultdict import logging +from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union -from typing_extensions import Self + import numpy as np +from typing_extensions import Self +from ray._common.deprecation import DEPRECATED_VALUE from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy @@ -24,21 +26,14 @@ from ray.rllib.execution.rollout_ops import ( synchronous_parallel_sample, ) -from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.execution.train_ops import ( - train_one_step, multi_gpu_train_one_step, + train_one_step, ) from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils import deep_update from ray.rllib.utils.annotations import override -from ray.rllib.utils.numpy import convert_to_numpy -from ray.rllib.utils.replay_buffers.utils import ( - update_priorities_in_episode_replay_buffer, - update_priorities_in_replay_buffer, - validate_buffer_config, -) -from ray.rllib.utils.typing import ResultDict from ray.rllib.utils.metrics import ( ALL_MODULES, ENV_RUNNER_RESULTS, @@ -60,10 
+55,16 @@ TD_ERROR_KEY, TIMERS, ) -from ray._common.deprecation import DEPRECATED_VALUE -from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.replay_buffers.utils import ( + sample_min_n_steps_from_buffer, + update_priorities_in_episode_replay_buffer, + update_priorities_in_replay_buffer, + validate_buffer_config, +) from ray.rllib.utils.typing import ( LearningRateOrSchedule, + ResultDict, RLModuleSpecType, SampleBatchType, ) diff --git a/rllib/algorithms/dqn/dqn_catalog.py b/rllib/algorithms/dqn/dqn_catalog.py index f98dc5429c3a..32c7cf1c063f 100644 --- a/rllib/algorithms/dqn/dqn_catalog.py +++ b/rllib/algorithms/dqn/dqn_catalog.py @@ -1,13 +1,13 @@ import gymnasium as gym -from ray.rllib.core.models.catalog import Catalog +from ray.rllib.core.distribution.torch.torch_distribution import TorchCategorical from ray.rllib.core.models.base import Model +from ray.rllib.core.models.catalog import Catalog from ray.rllib.core.models.configs import MLPHeadConfig -from ray.rllib.core.distribution.torch.torch_distribution import TorchCategorical from ray.rllib.utils.annotations import ( ExperimentalAPI, - override, OverrideToImplementCustomLogic, + override, ) diff --git a/rllib/algorithms/dqn/dqn_learner.py b/rllib/algorithms/dqn/dqn_learner.py index b55385eaf939..64bc51969a75 100644 --- a/rllib/algorithms/dqn/dqn_learner.py +++ b/rllib/algorithms/dqn/dqn_learner.py @@ -12,8 +12,8 @@ from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) from ray.rllib.utils.metrics import ( LAST_TARGET_UPDATE_TS, @@ -22,7 +22,6 @@ ) from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn - # Now, this is double defined: In `SACRLModule` and here. I would keep it here # or push it into the `Learner` as these are recurring keys in RL. 
ATOMS = "atoms" diff --git a/rllib/algorithms/dqn/dqn_torch_model.py b/rllib/algorithms/dqn/dqn_torch_model.py index 03c109878f73..4cb93bb63967 100644 --- a/rllib/algorithms/dqn/dqn_torch_model.py +++ b/rllib/algorithms/dqn/dqn_torch_model.py @@ -1,7 +1,9 @@ """PyTorch model for DQN""" from typing import Sequence + import gymnasium as gym + from ray.rllib.models.torch.misc import SlimFC from ray.rllib.models.torch.modules.noisy_layer import NoisyLayer from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 diff --git a/rllib/algorithms/dqn/dqn_torch_policy.py b/rllib/algorithms/dqn/dqn_torch_policy.py index 3229e379c730..fead64a5bc11 100644 --- a/rllib/algorithms/dqn/dqn_torch_policy.py +++ b/rllib/algorithms/dqn/dqn_torch_policy.py @@ -3,6 +3,7 @@ from typing import Dict, List, Tuple import gymnasium as gym + import ray from ray.rllib.algorithms.dqn.dqn_tf_policy import ( PRIO_WEIGHTS, @@ -14,8 +15,8 @@ from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import ( - get_torch_categorical_class_with_temperature, TorchDistributionWrapper, + get_torch_categorical_class_with_temperature, ) from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_template import build_policy_class @@ -29,15 +30,15 @@ from ray.rllib.utils.exploration.parameter_noise import ParameterNoise from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_utils import ( + FLOAT_MIN, apply_grad_clipping, concat_multi_gpu_td_errors, - FLOAT_MIN, huber_loss, l2_loss, reduce_mean_ignore_inf, softmax_cross_entropy_with_logits, ) -from ray.rllib.utils.typing import TensorType, AlgorithmConfigDict +from ray.rllib.utils.typing import AlgorithmConfigDict, TensorType torch, nn = try_import_torch() F = None diff --git a/rllib/algorithms/dqn/tests/test_dqn.py b/rllib/algorithms/dqn/tests/test_dqn.py index 238daefdb2f5..9805ce181d04 100644 --- a/rllib/algorithms/dqn/tests/test_dqn.py +++ b/rllib/algorithms/dqn/tests/test_dqn.py @@ -47,7 +47,8 @@ def test_dqn_compilation(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py b/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py index 968ebe2da68d..b1a07226d5c7 100644 --- a/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py +++ b/rllib/algorithms/dqn/torch/default_dqn_torch_rl_module.py @@ -1,8 +1,8 @@ -import tree from typing import Dict, Union +import tree + from ray.rllib.algorithms.dqn.default_dqn_rl_module import ( - DefaultDQNRLModule, ATOMS, QF_LOGITS, QF_NEXT_PREDS, @@ -10,16 +10,17 @@ QF_PROBS, QF_TARGET_NEXT_PREDS, QF_TARGET_NEXT_PROBS, + DefaultDQNRLModule, ) from ray.rllib.algorithms.dqn.dqn_catalog import DQNCatalog from ray.rllib.core.columns import Columns -from ray.rllib.core.models.base import Encoder, ENCODER_OUT, Model +from ray.rllib.core.models.base import ENCODER_OUT, Encoder, Model from ray.rllib.core.rl_module.apis.q_net_api import QNetAPI -from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule from ray.rllib.core.rl_module.rl_module import RLModule +from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.typing import TensorType, TensorStructType +from ray.rllib.utils.typing import TensorStructType, TensorType from 
ray.util.annotations import DeveloperAPI torch, nn = try_import_torch() diff --git a/rllib/algorithms/dqn/torch/dqn_torch_learner.py b/rllib/algorithms/dqn/torch/dqn_torch_learner.py index 4289bcc7cdcf..3e77529bc130 100644 --- a/rllib/algorithms/dqn/torch/dqn_torch_learner.py +++ b/rllib/algorithms/dqn/torch/dqn_torch_learner.py @@ -3,18 +3,18 @@ from ray.rllib.algorithms.dqn.dqn import DQNConfig from ray.rllib.algorithms.dqn.dqn_learner import ( ATOMS, - DQNLearner, - QF_LOSS_KEY, QF_LOGITS, - QF_MEAN_KEY, + QF_LOSS_KEY, QF_MAX_KEY, + QF_MEAN_KEY, QF_MIN_KEY, QF_NEXT_PREDS, - QF_TARGET_NEXT_PREDS, - QF_TARGET_NEXT_PROBS, QF_PREDS, QF_PROBS, + QF_TARGET_NEXT_PREDS, + QF_TARGET_NEXT_PROBS, TD_ERROR_MEAN_KEY, + DQNLearner, ) from ray.rllib.core.columns import Columns from ray.rllib.core.learner.torch.torch_learner import TorchLearner @@ -23,7 +23,6 @@ from ray.rllib.utils.metrics import TD_ERROR_KEY from ray.rllib.utils.typing import ModuleID, TensorType - torch, nn = try_import_torch() diff --git a/rllib/algorithms/dreamerv3/dreamerv3.py b/rllib/algorithms/dreamerv3/dreamerv3.py index ade67045c977..935f7a53a738 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3.py +++ b/rllib/algorithms/dreamerv3/dreamerv3.py @@ -10,9 +10,9 @@ import logging from typing import Any, Dict, Optional, Union -from typing_extensions import Self import gymnasium as gym +from typing_extensions import Self from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided @@ -33,8 +33,7 @@ from ray.rllib.execution.rollout_ops import synchronous_parallel_sample from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import deep_update -from ray.rllib.utils.annotations import override, PublicAPI -from ray.rllib.utils.numpy import one_hot +from ray.rllib.utils.annotations import PublicAPI, override from ray.rllib.utils.metrics import ( ENV_RUNNER_RESULTS, LEARN_ON_BATCH_TIMER, @@ -48,10 +47,10 @@ SYNCH_WORKER_WEIGHTS_TIMER, TIMERS, ) +from ray.rllib.utils.numpy import one_hot from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer from ray.rllib.utils.typing import LearningRateOrSchedule - logger = logging.getLogger(__name__) diff --git a/rllib/algorithms/dreamerv3/dreamerv3_catalog.py b/rllib/algorithms/dreamerv3/dreamerv3_catalog.py index a2cca266ec64..ce16b747ec4d 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3_catalog.py +++ b/rllib/algorithms/dreamerv3/dreamerv3_catalog.py @@ -4,11 +4,11 @@ from ray.rllib.algorithms.dreamerv3.utils import ( do_symlog_obs, get_gru_units, - get_num_z_classes, get_num_z_categoricals, + get_num_z_classes, ) -from ray.rllib.core.models.catalog import Catalog from ray.rllib.core.models.base import Encoder, Model +from ray.rllib.core.models.catalog import Catalog from ray.rllib.utils import override diff --git a/rllib/algorithms/dreamerv3/dreamerv3_learner.py b/rllib/algorithms/dreamerv3/dreamerv3_learner.py index 2bd634ca76e8..b2c0cf27cb22 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3_learner.py +++ b/rllib/algorithms/dreamerv3/dreamerv3_learner.py @@ -9,8 +9,8 @@ """ from ray.rllib.core.learner.learner import Learner from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) diff --git a/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py b/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py index 20e0b8140b16..5cf8f4884a97 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py +++ 
b/rllib/algorithms/dreamerv3/dreamerv3_rl_module.py @@ -5,21 +5,20 @@ import abc from typing import Dict +from ray.rllib.algorithms.dreamerv3.torch.models.actor_network import ActorNetwork +from ray.rllib.algorithms.dreamerv3.torch.models.critic_network import CriticNetwork +from ray.rllib.algorithms.dreamerv3.torch.models.dreamer_model import DreamerModel +from ray.rllib.algorithms.dreamerv3.torch.models.world_model import WorldModel from ray.rllib.algorithms.dreamerv3.utils import ( do_symlog_obs, get_gru_units, get_num_z_categoricals, get_num_z_classes, ) -from ray.rllib.algorithms.dreamerv3.torch.models.actor_network import ActorNetwork -from ray.rllib.algorithms.dreamerv3.torch.models.critic_network import CriticNetwork -from ray.rllib.algorithms.dreamerv3.torch.models.dreamer_model import DreamerModel -from ray.rllib.algorithms.dreamerv3.torch.models.world_model import WorldModel from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.utils.annotations import override from ray.util.annotations import DeveloperAPI - ACTIONS_ONE_HOT = "actions_one_hot" diff --git a/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py b/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py index a3936253a80b..096fdf7d7fa6 100644 --- a/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py +++ b/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py @@ -19,6 +19,7 @@ import tree # pip install dm_tree import ray +from ray import tune from ray.rllib.algorithms.dreamerv3 import dreamerv3 from ray.rllib.connectors.env_to_module import FlattenObservations from ray.rllib.core import DEFAULT_MODULE_ID @@ -27,7 +28,6 @@ from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import one_hot from ray.rllib.utils.test_utils import check -from ray import tune torch, nn = try_import_torch() @@ -317,7 +317,8 @@ def test_dreamerv3_dreamer_model_sizes(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/dreamerv3/torch/dreamerv3_torch_learner.py b/rllib/algorithms/dreamerv3/torch/dreamerv3_torch_learner.py index f33b9d2deb57..8d9c0ec4ea04 100644 --- a/rllib/algorithms/dreamerv3/torch/dreamerv3_torch_learner.py +++ b/rllib/algorithms/dreamerv3/torch/dreamerv3_torch_learner.py @@ -19,7 +19,7 @@ from ray.rllib.core.learner.torch.torch_learner import TorchLearner from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_utils import symlog, two_hot, clip_gradients +from ray.rllib.utils.torch_utils import clip_gradients, symlog, two_hot from ray.rllib.utils.typing import ModuleID, TensorType torch, nn = try_import_torch() diff --git a/rllib/algorithms/dreamerv3/torch/models/actor_network.py b/rllib/algorithms/dreamerv3/torch/models/actor_network.py index 8a02a41bd9bf..8dc90f4bdf9d 100644 --- a/rllib/algorithms/dreamerv3/torch/models/actor_network.py +++ b/rllib/algorithms/dreamerv3/torch/models/actor_network.py @@ -9,7 +9,6 @@ from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP from ray.rllib.utils.framework import try_import_torch - torch, nn = try_import_torch() diff --git a/rllib/algorithms/dreamerv3/torch/models/components/dynamics_predictor.py b/rllib/algorithms/dreamerv3/torch/models/components/dynamics_predictor.py index 14e8a39c829a..64df56079bda 100644 --- a/rllib/algorithms/dreamerv3/torch/models/components/dynamics_predictor.py +++ 
b/rllib/algorithms/dreamerv3/torch/models/components/dynamics_predictor.py @@ -5,10 +5,10 @@ """ from typing import Optional -from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP from ray.rllib.algorithms.dreamerv3.torch.models.components import ( representation_layer, ) +from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP from ray.rllib.algorithms.dreamerv3.utils import get_dense_hidden_units from ray.rllib.utils.framework import try_import_torch diff --git a/rllib/algorithms/dreamerv3/torch/models/components/reward_predictor.py b/rllib/algorithms/dreamerv3/torch/models/components/reward_predictor.py index 2733dd2cc132..98f5920f5890 100644 --- a/rllib/algorithms/dreamerv3/torch/models/components/reward_predictor.py +++ b/rllib/algorithms/dreamerv3/torch/models/components/reward_predictor.py @@ -3,12 +3,11 @@ D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap https://arxiv.org/pdf/2301.04104v1.pdf """ -from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP from ray.rllib.algorithms.dreamerv3.torch.models.components import ( reward_predictor_layer, ) +from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP from ray.rllib.algorithms.dreamerv3.utils import get_dense_hidden_units - from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/algorithms/dreamerv3/torch/models/components/sequence_model.py b/rllib/algorithms/dreamerv3/torch/models/components/sequence_model.py index 1fbd695d54cc..38934a016aa6 100644 --- a/rllib/algorithms/dreamerv3/torch/models/components/sequence_model.py +++ b/rllib/algorithms/dreamerv3/torch/models/components/sequence_model.py @@ -11,7 +11,7 @@ dreamerv3_normal_initializer, ) from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP -from ray.rllib.algorithms.dreamerv3.utils import get_gru_units, get_dense_hidden_units +from ray.rllib.algorithms.dreamerv3.utils import get_dense_hidden_units, get_gru_units from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/algorithms/dreamerv3/torch/models/critic_network.py b/rllib/algorithms/dreamerv3/torch/models/critic_network.py index d4b5798eb55a..f4b4fb956778 100644 --- a/rllib/algorithms/dreamerv3/torch/models/critic_network.py +++ b/rllib/algorithms/dreamerv3/torch/models/critic_network.py @@ -3,11 +3,11 @@ D. Hafner, J. Pasukonis, J. Ba, T. 
Lillicrap https://arxiv.org/pdf/2301.04104v1.pdf """ -from ray.rllib.algorithms.dreamerv3.utils import get_dense_hidden_units -from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP from ray.rllib.algorithms.dreamerv3.torch.models.components import ( reward_predictor_layer, ) +from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP +from ray.rllib.algorithms.dreamerv3.utils import get_dense_hidden_units from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/algorithms/dreamerv3/torch/models/world_model.py b/rllib/algorithms/dreamerv3/torch/models/world_model.py index c8851ea8dd71..5e2b4de597f3 100644 --- a/rllib/algorithms/dreamerv3/torch/models/world_model.py +++ b/rllib/algorithms/dreamerv3/torch/models/world_model.py @@ -9,6 +9,9 @@ import numpy as np import tree # pip install dm_tree +from ray.rllib.algorithms.dreamerv3.torch.models.components import ( + representation_layer, +) from ray.rllib.algorithms.dreamerv3.torch.models.components.continue_predictor import ( ContinuePredictor, ) @@ -16,9 +19,6 @@ DynamicsPredictor, ) from ray.rllib.algorithms.dreamerv3.torch.models.components.mlp import MLP -from ray.rllib.algorithms.dreamerv3.torch.models.components import ( - representation_layer, -) from ray.rllib.algorithms.dreamerv3.torch.models.components.reward_predictor import ( RewardPredictor, ) diff --git a/rllib/algorithms/dreamerv3/utils/debugging.py b/rllib/algorithms/dreamerv3/utils/debugging.py index d69281713a38..a99d2923d4ad 100644 --- a/rllib/algorithms/dreamerv3/utils/debugging.py +++ b/rllib/algorithms/dreamerv3/utils/debugging.py @@ -1,8 +1,7 @@ import gymnasium as gym import numpy as np -from PIL import Image, ImageDraw - from gymnasium.envs.classic_control.cartpole import CartPoleEnv +from PIL import Image, ImageDraw from ray.rllib.utils.framework import try_import_torch diff --git a/rllib/algorithms/impala/__init__.py b/rllib/algorithms/impala/__init__.py index 913c1b77198e..f81a5666eb0d 100644 --- a/rllib/algorithms/impala/__init__.py +++ b/rllib/algorithms/impala/__init__.py @@ -1,7 +1,7 @@ from ray.rllib.algorithms.impala.impala import ( IMPALA, - IMPALAConfig, Impala, + IMPALAConfig, ImpalaConfig, ) from ray.rllib.algorithms.impala.impala_tf_policy import ( diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index dc3f14013573..ce0f3d8555ce 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -32,8 +32,8 @@ LEARNER_RESULTS, LEARNER_UPDATE_TIMER, MEAN_NUM_EPISODE_LISTS_RECEIVED, - MEAN_NUM_LEARNER_RESULTS_RECEIVED, MEAN_NUM_LEARNER_GROUP_UPDATE_CALLED, + MEAN_NUM_LEARNER_RESULTS_RECEIVED, NUM_AGENT_STEPS_SAMPLED, NUM_AGENT_STEPS_TRAINED, NUM_ENV_STEPS_SAMPLED, @@ -42,8 +42,8 @@ NUM_ENV_STEPS_TRAINED_LIFETIME, NUM_SYNCH_WORKER_WEIGHTS, NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS, - SYNCH_WORKER_WEIGHTS_TIMER, SAMPLE_TIMER, + SYNCH_WORKER_WEIGHTS_TIMER, TIMERS, ) from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder diff --git a/rllib/algorithms/impala/impala_learner.py b/rllib/algorithms/impala/impala_learner.py index 9c01248203a0..ecedaf1ce1f7 100644 --- a/rllib/algorithms/impala/impala_learner.py +++ b/rllib/algorithms/impala/impala_learner.py @@ -13,8 +13,8 @@ from ray.rllib.core.learner.training_data import TrainingData from ray.rllib.core.rl_module.apis import ValueFunctionAPI from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, 
+ override, ) from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict diff --git a/rllib/algorithms/impala/impala_tf_policy.py b/rllib/algorithms/impala/impala_tf_policy.py index a1f74f48f8ce..94ee60e20260 100644 --- a/rllib/algorithms/impala/impala_tf_policy.py +++ b/rllib/algorithms/impala/impala_tf_policy.py @@ -2,11 +2,12 @@ Keep in sync with changes to A3CTFPolicy and VtraceSurrogatePolicy.""" -import numpy as np import logging -import gymnasium as gym from typing import Dict, List, Optional, Type, Union +import gymnasium as gym +import numpy as np + from ray.rllib.algorithms.impala import vtrace_tf as vtrace from ray.rllib.evaluation.postprocessing import compute_bootstrap_value from ray.rllib.models.modelv2 import ModelV2 @@ -14,12 +15,16 @@ from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.tf_mixins import LearningRateSchedule, EntropyCoeffSchedule +from ray.rllib.policy.tf_mixins import ( + EntropyCoeffSchedule, + GradStatsMixin, + LearningRateSchedule, + ValueNetworkMixin, +) from ray.rllib.utils import force_list from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.tf_utils import explained_variance -from ray.rllib.policy.tf_mixins import GradStatsMixin, ValueNetworkMixin from ray.rllib.utils.typing import ( LocalOptimizer, ModelGradients, diff --git a/rllib/algorithms/impala/impala_torch_policy.py b/rllib/algorithms/impala/impala_torch_policy.py index c174149f7c60..ee58654cab7b 100644 --- a/rllib/algorithms/impala/impala_torch_policy.py +++ b/rllib/algorithms/impala/impala_torch_policy.py @@ -1,12 +1,13 @@ -import gymnasium as gym import logging -import numpy as np from typing import Dict, List, Optional, Type, Union +import gymnasium as gym +import numpy as np + import ray from ray.rllib.evaluation.postprocessing import compute_bootstrap_value -from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.action_dist import ActionDistribution +from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.torch.torch_action_dist import TorchCategorical from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.policy.torch_mixins import ( diff --git a/rllib/algorithms/impala/tests/test_impala.py b/rllib/algorithms/impala/tests/test_impala.py index 868062f019ea..be5ee0eccfb9 100644 --- a/rllib/algorithms/impala/tests/test_impala.py +++ b/rllib/algorithms/impala/tests/test_impala.py @@ -64,7 +64,8 @@ def get_lr(result): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/impala/tests/test_vtrace_old_api_stack.py b/rllib/algorithms/impala/tests/test_vtrace_old_api_stack.py index 303797f2f947..d538a032eecb 100644 --- a/rllib/algorithms/impala/tests/test_vtrace_old_api_stack.py +++ b/rllib/algorithms/impala/tests/test_vtrace_old_api_stack.py @@ -20,10 +20,11 @@ by Espeholt, Soyer, Munos et al. 
""" -from gymnasium.spaces import Box -import numpy as np import unittest +import numpy as np +from gymnasium.spaces import Box + from ray.rllib.algorithms.impala import vtrace_torch as vtrace_torch from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.numpy import softmax @@ -282,7 +283,8 @@ def test_inconsistent_rank_inputs_for_importance_weights(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/impala/tests/test_vtrace_v2.py b/rllib/algorithms/impala/tests/test_vtrace_v2.py index 79387104d968..fda785d3df90 100644 --- a/rllib/algorithms/impala/tests/test_vtrace_v2.py +++ b/rllib/algorithms/impala/tests/test_vtrace_v2.py @@ -1,19 +1,19 @@ import unittest -import numpy as np +import numpy as np from gymnasium.spaces import Box, Discrete -from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import ( - vtrace_torch, - make_time_major, -) from ray.rllib.algorithms.impala.tests.test_vtrace_old_api_stack import ( _ground_truth_vtrace_calculation, ) -from ray.rllib.utils.torch_utils import convert_to_torch_tensor +from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import ( + make_time_major, + vtrace_torch, +) from ray.rllib.core.distribution.torch.torch_distribution import TorchCategorical from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.test_utils import check +from ray.rllib.utils.torch_utils import convert_to_torch_tensor torch, _ = try_import_torch() @@ -147,7 +147,8 @@ def test_vtrace_torch(self): if __name__ == "__main__": - import pytest import sys + import pytest + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/impala/torch/impala_torch_learner.py b/rllib/algorithms/impala/torch/impala_torch_learner.py index 256e3b48fb79..f52250d16a1c 100644 --- a/rllib/algorithms/impala/torch/impala_torch_learner.py +++ b/rllib/algorithms/impala/torch/impala_torch_learner.py @@ -3,8 +3,8 @@ from ray.rllib.algorithms.impala.impala import IMPALAConfig from ray.rllib.algorithms.impala.impala_learner import IMPALALearner from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import ( - vtrace_torch, make_time_major, + vtrace_torch, ) from ray.rllib.core.columns import Columns from ray.rllib.core.learner.learner import ENTROPY_KEY diff --git a/rllib/algorithms/impala/torch/vtrace_torch_v2.py b/rllib/algorithms/impala/torch/vtrace_torch_v2.py index 48231be9d7d5..bf4c4fa99373 100644 --- a/rllib/algorithms/impala/torch/vtrace_torch_v2.py +++ b/rllib/algorithms/impala/torch/vtrace_torch_v2.py @@ -1,4 +1,5 @@ from typing import List, Union + from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() diff --git a/rllib/algorithms/iql/default_iql_rl_module.py b/rllib/algorithms/iql/default_iql_rl_module.py index e6e3b2279ac5..95596bd8b91d 100644 --- a/rllib/algorithms/iql/default_iql_rl_module.py +++ b/rllib/algorithms/iql/default_iql_rl_module.py @@ -2,8 +2,8 @@ from ray.rllib.core.models.configs import MLPHeadConfig from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI from ray.rllib.utils.annotations import ( - override, OverrideToImplementCustomLogic_CallToSuperRecommended, + override, ) diff --git a/rllib/algorithms/iql/iql_learner.py b/rllib/algorithms/iql/iql_learner.py index 5821f2ccb5e0..ef2e2e83e15b 100644 --- a/rllib/algorithms/iql/iql_learner.py +++ b/rllib/algorithms/iql/iql_learner.py @@ -2,8 +2,8 @@ from ray.rllib.algorithms.dqn.dqn_learner import DQNLearner from 
ray.rllib.utils.annotations import (
-    override,
     OverrideToImplementCustomLogic_CallToSuperRecommended,
+    override,
 )
 from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
 from ray.rllib.utils.typing import ModuleID, TensorType
diff --git a/rllib/algorithms/iql/torch/default_iql_torch_rl_module.py b/rllib/algorithms/iql/torch/default_iql_torch_rl_module.py
index 00d7fc821e49..318dcc207533 100644
--- a/rllib/algorithms/iql/torch/default_iql_torch_rl_module.py
+++ b/rllib/algorithms/iql/torch/default_iql_torch_rl_module.py
@@ -1,8 +1,9 @@
-import gymnasium as gym
 from typing import Any, Dict, Optional
 
+import gymnasium as gym
+
 from ray.rllib.algorithms.iql.default_iql_rl_module import DefaultIQLRLModule
-from ray.rllib.algorithms.iql.iql_learner import VF_PREDS_NEXT, QF_TARGET_PREDS
+from ray.rllib.algorithms.iql.iql_learner import QF_TARGET_PREDS, VF_PREDS_NEXT
 from ray.rllib.algorithms.sac.torch.default_sac_torch_rl_module import (
     DefaultSACTorchRLModule,
 )
diff --git a/rllib/algorithms/iql/torch/iql_torch_learner.py b/rllib/algorithms/iql/torch/iql_torch_learner.py
index 85dc68e86fb2..54a4fd263caa 100644
--- a/rllib/algorithms/iql/torch/iql_torch_learner.py
+++ b/rllib/algorithms/iql/torch/iql_torch_learner.py
@@ -1,14 +1,14 @@
 from typing import Dict
 
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
-from ray.rllib.algorithms.dqn.dqn_learner import QF_PREDS, QF_LOSS_KEY
+from ray.rllib.algorithms.dqn.dqn_learner import QF_LOSS_KEY, QF_PREDS
 from ray.rllib.algorithms.iql.iql_learner import (
-    IQLLearner,
     QF_TARGET_PREDS,
-    VF_PREDS_NEXT,
     VF_LOSS,
+    VF_PREDS_NEXT,
+    IQLLearner,
 )
-from ray.rllib.algorithms.sac.sac_learner import QF_TWIN_PREDS, QF_TWIN_LOSS_KEY
+from ray.rllib.algorithms.sac.sac_learner import QF_TWIN_LOSS_KEY, QF_TWIN_PREDS
 from ray.rllib.core import ALL_MODULES
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.learner.learner import (
diff --git a/rllib/algorithms/marwil/marwil.py b/rllib/algorithms/marwil/marwil.py
index 5685e4df127f..e54843213e64 100644
--- a/rllib/algorithms/marwil/marwil.py
+++ b/rllib/algorithms/marwil/marwil.py
@@ -1,12 +1,14 @@
 from typing import Callable, Optional, Type, Union
+
 from typing_extensions import Self
 
+from ray._common.deprecation import deprecation_warning
 from ray.rllib.algorithms.algorithm import Algorithm
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
 from ray.rllib.connectors.learner import (
+    AddNextObservationsFromEpisodesToTrainBatch,
     AddObservationsFromEpisodesToBatch,
     AddOneTsToEpisodesAndTruncate,
-    AddNextObservationsFromEpisodesToTrainBatch,
     GeneralAdvantageEstimation,
 )
 from ray.rllib.core.learner.learner import Learner
@@ -21,7 +23,6 @@
 )
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import OldAPIStack, override
-from ray._common.deprecation import deprecation_warning
 from ray.rllib.utils.metrics import (
     LEARNER_RESULTS,
     LEARNER_UPDATE_TIMER,
diff --git a/rllib/algorithms/marwil/marwil_learner.py b/rllib/algorithms/marwil/marwil_learner.py
index 363e6a84a309..b98b0d090f66 100644
--- a/rllib/algorithms/marwil/marwil_learner.py
+++ b/rllib/algorithms/marwil/marwil_learner.py
@@ -1,7 +1,7 @@
 from typing import Dict, Optional
 
-from ray.rllib.core.rl_module.apis import ValueFunctionAPI
 from ray.rllib.core.learner.learner import Learner
+from ray.rllib.core.rl_module.apis import ValueFunctionAPI
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
 from ray.rllib.utils.typing import ModuleID, ShouldModuleBeUpdatedFn, TensorType
diff --git a/rllib/algorithms/marwil/marwil_tf_policy.py b/rllib/algorithms/marwil/marwil_tf_policy.py
index 5f75a8424c76..2dbdb6a0efd2 100644
--- a/rllib/algorithms/marwil/marwil_tf_policy.py
+++ b/rllib/algorithms/marwil/marwil_tf_policy.py
@@ -1,7 +1,7 @@
 import logging
 from typing import Any, Dict, List, Optional, Type, Union
 
-from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing
+from ray.rllib.evaluation.postprocessing import Postprocessing, compute_advantages
 from ray.rllib.models.action_dist import ActionDistribution
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution
@@ -14,7 +14,7 @@
     compute_gradients,
 )
 from ray.rllib.utils.annotations import override
-from ray.rllib.utils.framework import try_import_tf, get_variable
+from ray.rllib.utils.framework import get_variable, try_import_tf
 from ray.rllib.utils.tf_utils import explained_variance
 from ray.rllib.utils.typing import (
     LocalOptimizer,
diff --git a/rllib/algorithms/marwil/tests/test_marwil.py b/rllib/algorithms/marwil/tests/test_marwil.py
index 3bfc8ff30231..5f39cf9752c0 100644
--- a/rllib/algorithms/marwil/tests/test_marwil.py
+++ b/rllib/algorithms/marwil/tests/test_marwil.py
@@ -1,11 +1,12 @@
+import unittest
+from pathlib import Path
+
 import gymnasium as gym
 import numpy as np
-from pathlib import Path
-import unittest
 
 import ray
 import ray.rllib.algorithms.marwil as marwil
-from ray.rllib.core import DEFAULT_MODULE_ID, COMPONENT_RL_MODULE
+from ray.rllib.core import COMPONENT_RL_MODULE, DEFAULT_MODULE_ID
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY
 from ray.rllib.env import INPUT_ENV_SPACES
@@ -232,7 +233,8 @@ def possibly_masked_mean(data_):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/marwil/tests/test_marwil_rl_module.py b/rllib/algorithms/marwil/tests/test_marwil_rl_module.py
index 683180d0609a..8ea50e5be7f3 100644
--- a/rllib/algorithms/marwil/tests/test_marwil_rl_module.py
+++ b/rllib/algorithms/marwil/tests/test_marwil_rl_module.py
@@ -1,9 +1,9 @@
 import itertools
 import unittest
-import ray
-
 from pathlib import Path
 
+import ray
+
 
 class TestMARWIL(unittest.TestCase):
     @classmethod
@@ -31,7 +31,8 @@ def test_rollouts(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/mock.py b/rllib/algorithms/mock.py
index 25707cf1677b..ba2ac262af21 100644
--- a/rllib/algorithms/mock.py
+++ b/rllib/algorithms/mock.py
@@ -4,9 +4,9 @@
 
 import numpy as np
 
-from ray.tune import result as tune_result
 from ray.rllib.algorithms.algorithm import Algorithm, AlgorithmConfig
 from ray.rllib.utils.annotations import override
+from ray.tune import result as tune_result
 
 
 class _MockTrainer(Algorithm):
diff --git a/rllib/algorithms/ppo/__init__.py b/rllib/algorithms/ppo/__init__.py
index a02982e64a53..9ed907f5dd1e 100644
--- a/rllib/algorithms/ppo/__init__.py
+++ b/rllib/algorithms/ppo/__init__.py
@@ -1,4 +1,4 @@
-from ray.rllib.algorithms.ppo.ppo import PPOConfig, PPO
+from ray.rllib.algorithms.ppo.ppo import PPO, PPOConfig
 from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy, PPOTF2Policy
 from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy
 
diff --git a/rllib/algorithms/ppo/default_ppo_rl_module.py b/rllib/algorithms/ppo/default_ppo_rl_module.py
index 1216eeef0d75..5ac176452f36 100644
--- a/rllib/algorithms/ppo/default_ppo_rl_module.py
+++ b/rllib/algorithms/ppo/default_ppo_rl_module.py
@@ -5,8 +5,8 @@
 from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, ValueFunctionAPI
 from ray.rllib.core.rl_module.rl_module import RLModule
 from ray.rllib.utils.annotations import (
-    override,
     OverrideToImplementCustomLogic_CallToSuperRecommended,
+    override,
 )
 from ray.util.annotations import DeveloperAPI
 
diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py
index 407565fc4383..d94bd0a5c6f3 100644
--- a/rllib/algorithms/ppo/ppo.py
+++ b/rllib/algorithms/ppo/ppo.py
@@ -10,9 +10,11 @@
 """
 
 import logging
-from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
+
 from typing_extensions import Self
 
+from ray._common.deprecation import DEPRECATED_VALUE
 from ray.rllib.algorithms.algorithm import Algorithm
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
 from ray.rllib.core.rl_module.rl_module import RLModuleSpec
@@ -21,13 +23,13 @@
     synchronous_parallel_sample,
 )
 from ray.rllib.execution.train_ops import (
-    train_one_step,
     multi_gpu_train_one_step,
+    train_one_step,
 )
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils.annotations import OldAPIStack, override
-from ray._common.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.metrics import (
+    ALL_MODULES,
     ENV_RUNNER_RESULTS,
     ENV_RUNNER_SAMPLING_TIMER,
     LEARNER_RESULTS,
@@ -35,10 +37,9 @@
     NUM_AGENT_STEPS_SAMPLED,
     NUM_ENV_STEPS_SAMPLED,
     NUM_ENV_STEPS_SAMPLED_LIFETIME,
-    SYNCH_WORKER_WEIGHTS_TIMER,
     SAMPLE_TIMER,
+    SYNCH_WORKER_WEIGHTS_TIMER,
     TIMERS,
-    ALL_MODULES,
 )
 from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
 from ray.rllib.utils.schedules.scheduler import Scheduler
diff --git a/rllib/algorithms/ppo/ppo_catalog.py b/rllib/algorithms/ppo/ppo_catalog.py
index fb11efea17ba..e88b761427a2 100644
--- a/rllib/algorithms/ppo/ppo_catalog.py
+++ b/rllib/algorithms/ppo/ppo_catalog.py
@@ -1,13 +1,13 @@
 # __sphinx_doc_begin__
 import gymnasium as gym
 
+from ray.rllib.core.models.base import ActorCriticEncoder, Encoder, Model
 from ray.rllib.core.models.catalog import Catalog
 from ray.rllib.core.models.configs import (
     ActorCriticEncoderConfig,
-    MLPHeadConfig,
     FreeLogStdMLPHeadConfig,
+    MLPHeadConfig,
 )
-from ray.rllib.core.models.base import Encoder, ActorCriticEncoder, Model
 from ray.rllib.utils import override
 from ray.rllib.utils.annotations import OverrideToImplementCustomLogic
 
diff --git a/rllib/algorithms/ppo/ppo_learner.py b/rllib/algorithms/ppo/ppo_learner.py
index b6d3953a8a45..ef16f71c98bb 100644
--- a/rllib/algorithms/ppo/ppo_learner.py
+++ b/rllib/algorithms/ppo/ppo_learner.py
@@ -13,8 +13,8 @@
 from ray.rllib.core.learner.learner import Learner
 from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
 from ray.rllib.utils.annotations import (
-    override,
     OverrideToImplementCustomLogic_CallToSuperRecommended,
+    override,
 )
 from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict
 from ray.rllib.utils.metrics import (
diff --git a/rllib/algorithms/ppo/tests/test_ppo.py b/rllib/algorithms/ppo/tests/test_ppo.py
index a0544a566944..6531d2d8f5cf 100644
--- a/rllib/algorithms/ppo/tests/test_ppo.py
+++ b/rllib/algorithms/ppo/tests/test_ppo.py
@@ -161,7 +161,8 @@ def get_value(log_std_var=log_std_var):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/ppo/tests/test_ppo_learner.py b/rllib/algorithms/ppo/tests/test_ppo_learner.py
index 1d5f83639bb9..825b1411b948 100644
--- a/rllib/algorithms/ppo/tests/test_ppo_learner.py
+++ b/rllib/algorithms/ppo/tests/test_ppo_learner.py
@@ -1,5 +1,5 @@
-import unittest
 import tempfile
+import unittest
 
 import gymnasium as gym
 import numpy as np
@@ -13,7 +13,6 @@
 from ray.rllib.utils.test_utils import check
 from ray.tune.registry import register_env
 
-
 # Fake CartPole episode of n time steps.
 FAKE_BATCH = {
     Columns.OBS: np.array(
@@ -136,7 +135,8 @@ def test_kl_coeff_changes(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py
index c7d786d3d1a5..a8d5999a586d 100644
--- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py
+++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py
@@ -17,7 +17,6 @@
 from ray.rllib.utils.numpy import convert_to_numpy
 from ray.rllib.utils.torch_utils import convert_to_torch_tensor
 
-
 torch, nn = try_import_torch()
 
 
@@ -186,7 +185,8 @@ def test_forward_train(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/ppo/torch/ppo_torch_learner.py b/rllib/algorithms/ppo/torch/ppo_torch_learner.py
index 190ecbf106c1..4e7a806f98ab 100644
--- a/rllib/algorithms/ppo/torch/ppo_torch_learner.py
+++ b/rllib/algorithms/ppo/torch/ppo_torch_learner.py
@@ -4,15 +4,15 @@
 import numpy as np
 
 from ray.rllib.algorithms.ppo.ppo import (
-    LEARNER_RESULTS_KL_KEY,
     LEARNER_RESULTS_CURR_KL_COEFF_KEY,
+    LEARNER_RESULTS_KL_KEY,
     LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY,
     LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY,
     PPOConfig,
 )
 from ray.rllib.algorithms.ppo.ppo_learner import PPOLearner
 from ray.rllib.core.columns import Columns
-from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY
+from ray.rllib.core.learner.learner import ENTROPY_KEY, POLICY_LOSS_KEY, VF_LOSS_KEY
 from ray.rllib.core.learner.torch.torch_learner import TorchLearner
 from ray.rllib.evaluation.postprocessing import Postprocessing
 from ray.rllib.utils.annotations import override
diff --git a/rllib/algorithms/sac/default_sac_rl_module.py b/rllib/algorithms/sac/default_sac_rl_module.py
index 3d01e5ed5ccc..76f02a1e4c7f 100644
--- a/rllib/algorithms/sac/default_sac_rl_module.py
+++ b/rllib/algorithms/sac/default_sac_rl_module.py
@@ -6,8 +6,8 @@
 from ray.rllib.core.rl_module.apis import InferenceOnlyAPI, QNetAPI, TargetNetworkAPI
 from ray.rllib.core.rl_module.rl_module import RLModule
 from ray.rllib.utils.annotations import (
-    override,
     OverrideToImplementCustomLogic,
+    override,
 )
 from ray.rllib.utils.typing import NetworkType
 from ray.util.annotations import DeveloperAPI
diff --git a/rllib/algorithms/sac/sac.py b/rllib/algorithms/sac/sac.py
index 7d9c6852ed90..d464e95889db 100644
--- a/rllib/algorithms/sac/sac.py
+++ b/rllib/algorithms/sac/sac.py
@@ -1,7 +1,9 @@
 import logging
 from typing import Any, Dict, Optional, Tuple, Type, Union
+
 from typing_extensions import Self
 
+from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
 from ray.rllib.algorithms.dqn.dqn import DQN
 from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy
@@ -16,7 +18,6 @@
 from ray.rllib.policy.policy import Policy
 from ray.rllib.utils import deep_update
 from ray.rllib.utils.annotations import override
-from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning
 from ray.rllib.utils.framework import try_import_tf, try_import_tfp
 from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
 from ray.rllib.utils.typing import LearningRateOrSchedule, RLModuleSpecType
diff --git a/rllib/algorithms/sac/sac_catalog.py b/rllib/algorithms/sac/sac_catalog.py
index c60bfd77992f..2ebe4470af18 100644
--- a/rllib/algorithms/sac/sac_catalog.py
+++ b/rllib/algorithms/sac/sac_catalog.py
@@ -1,23 +1,24 @@
+from typing import Callable
+
 import gymnasium as gym
 import numpy as np
-from typing import Callable
 
 # TODO (simon): Store this function somewhere more central as many
 # algorithms will use it.
 from ray.rllib.algorithms.ppo.ppo_catalog import _check_if_diag_gaussian
+from ray.rllib.core.distribution.distribution import Distribution
+from ray.rllib.core.distribution.torch.torch_distribution import (
+    TorchCategorical,
+    TorchSquashedGaussian,
+)
+from ray.rllib.core.models.base import Encoder, Model
 from ray.rllib.core.models.catalog import Catalog
 from ray.rllib.core.models.configs import (
     FreeLogStdMLPHeadConfig,
     MLPEncoderConfig,
     MLPHeadConfig,
 )
-from ray.rllib.core.models.base import Encoder, Model
-from ray.rllib.core.distribution.torch.torch_distribution import (
-    TorchSquashedGaussian,
-    TorchCategorical,
-)
-from ray.rllib.utils.annotations import override, OverrideToImplementCustomLogic
-from ray.rllib.core.distribution.distribution import Distribution
+from ray.rllib.utils.annotations import OverrideToImplementCustomLogic, override
 
 
 # TODO (simon): Check, if we can directly derive from DQNCatalog.
diff --git a/rllib/algorithms/sac/sac_learner.py b/rllib/algorithms/sac/sac_learner.py
index f4108943ad04..abbf082b1ca1 100644
--- a/rllib/algorithms/sac/sac_learner.py
+++ b/rllib/algorithms/sac/sac_learner.py
@@ -1,7 +1,7 @@
-import numpy as np
-
 from typing import Dict
 
+import numpy as np
+
 from ray.rllib.algorithms.dqn.dqn_learner import DQNLearner
 from ray.rllib.core.learner.learner import Learner
 from ray.rllib.utils.annotations import override
diff --git a/rllib/algorithms/sac/sac_tf_model.py b/rllib/algorithms/sac/sac_tf_model.py
index 7302a25fcccf..e3b3479ff684 100644
--- a/rllib/algorithms/sac/sac_tf_model.py
+++ b/rllib/algorithms/sac/sac_tf_model.py
@@ -1,15 +1,16 @@
+from typing import Dict, List, Optional
+
 import gymnasium as gym
-from gymnasium.spaces import Box, Discrete
 import numpy as np
 import tree  # pip install dm_tree
-from typing import Dict, List, Optional
+from gymnasium.spaces import Box, Discrete
 
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.tf.tf_modelv2 import TFModelV2
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_tf
 from ray.rllib.utils.spaces.simplex import Simplex
-from ray.rllib.utils.typing import ModelConfigDict, TensorType, TensorStructType
+from ray.rllib.utils.typing import ModelConfigDict, TensorStructType, TensorType
 
 
 tf1, tf, tfv = try_import_tf()
diff --git a/rllib/algorithms/sac/sac_tf_policy.py b/rllib/algorithms/sac/sac_tf_policy.py
index 2ce3184c70d9..e4322518e46a 100644
--- a/rllib/algorithms/sac/sac_tf_policy.py
+++ b/rllib/algorithms/sac/sac_tf_policy.py
@@ -3,16 +3,17 @@
 """
 
 import copy
-import gymnasium as gym
-from gymnasium.spaces import Box, Discrete
-from functools import partial
 import logging
+from functools import partial
 from typing import Dict, List, Optional, Tuple, Type, Union
 
+import gymnasium as gym
+from gymnasium.spaces import Box, Discrete
+
 import ray
 from ray.rllib.algorithms.dqn.dqn_tf_policy import (
-    postprocess_nstep_and_prio,
     PRIO_WEIGHTS,
+    postprocess_nstep_and_prio,
 )
 from ray.rllib.algorithms.sac.sac_tf_model import SACTFModel
 from ray.rllib.algorithms.sac.sac_torch_model import SACTorchModel
@@ -36,10 +37,10 @@
 from ray.rllib.utils.tf_utils import huber_loss, make_tf_callable
 from ray.rllib.utils.typing import (
     AgentID,
+    AlgorithmConfigDict,
     LocalOptimizer,
     ModelGradients,
     TensorType,
-    AlgorithmConfigDict,
 )
 
 tf1, tf, tfv = try_import_tf()
diff --git a/rllib/algorithms/sac/sac_torch_model.py b/rllib/algorithms/sac/sac_torch_model.py
index 00219fd95b8a..8c2fcd5b530c 100644
--- a/rllib/algorithms/sac/sac_torch_model.py
+++ b/rllib/algorithms/sac/sac_torch_model.py
@@ -1,15 +1,16 @@
+from typing import Dict, List, Optional
+
 import gymnasium as gym
-from gymnasium.spaces import Box, Discrete
 import numpy as np
 import tree  # pip install dm_tree
-from typing import Dict, List, Optional
+from gymnasium.spaces import Box, Discrete
 
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.spaces.simplex import Simplex
-from ray.rllib.utils.typing import ModelConfigDict, TensorType, TensorStructType
+from ray.rllib.utils.typing import ModelConfigDict, TensorStructType, TensorType
 
 
 torch, nn = try_import_torch()
diff --git a/rllib/algorithms/sac/sac_torch_policy.py b/rllib/algorithms/sac/sac_torch_policy.py
index cef30f465f5d..b105b856ed0b 100644
--- a/rllib/algorithms/sac/sac_torch_policy.py
+++ b/rllib/algorithms/sac/sac_torch_policy.py
@@ -2,45 +2,46 @@
 PyTorch policy class used for SAC.
 """
 
-import gymnasium as gym
-from gymnasium.spaces import Box, Discrete
 import logging
-import tree  # pip install dm_tree
 from typing import Dict, List, Optional, Tuple, Type, Union
 
+import gymnasium as gym
+import tree  # pip install dm_tree
+from gymnasium.spaces import Box, Discrete
+
 import ray
+from ray.rllib.algorithms.dqn.dqn_tf_policy import PRIO_WEIGHTS
 from ray.rllib.algorithms.sac.sac_tf_policy import (
     build_sac_model,
     postprocess_trajectory,
     validate_spaces,
 )
-from ray.rllib.algorithms.dqn.dqn_tf_policy import PRIO_WEIGHTS
 from ray.rllib.models.catalog import ModelCatalog
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.torch.torch_action_dist import (
+    TorchBeta,
     TorchCategorical,
-    TorchDistributionWrapper,
+    TorchDiagGaussian,
     TorchDirichlet,
+    TorchDistributionWrapper,
     TorchSquashedGaussian,
-    TorchDiagGaussian,
-    TorchBeta,
 )
 from ray.rllib.policy.policy import Policy
 from ray.rllib.policy.policy_template import build_policy_class
 from ray.rllib.policy.sample_batch import SampleBatch
+from ray.rllib.policy.torch_mixins import TargetNetworkMixin
 from ray.rllib.utils.framework import try_import_torch
 from ray.rllib.utils.spaces.simplex import Simplex
-from ray.rllib.policy.torch_mixins import TargetNetworkMixin
 from ray.rllib.utils.torch_utils import (
     apply_grad_clipping,
     concat_multi_gpu_td_errors,
     huber_loss,
 )
 from ray.rllib.utils.typing import (
+    AlgorithmConfigDict,
     LocalOptimizer,
     ModelInputDict,
     TensorType,
-    AlgorithmConfigDict,
 )
 
 torch, nn = try_import_torch()
diff --git a/rllib/algorithms/sac/tests/test_sac.py b/rllib/algorithms/sac/tests/test_sac.py
index be49d2fd5f81..4d03ba92db63 100644
--- a/rllib/algorithms/sac/tests/test_sac.py
+++ b/rllib/algorithms/sac/tests/test_sac.py
@@ -1,16 +1,17 @@
+import unittest
+
 import gymnasium as gym
-from gymnasium.spaces import Box, Dict, Discrete, Tuple
 import numpy as np
-import unittest
+from gymnasium.spaces import Box, Dict, Discrete, Tuple
 
 import ray
+from ray import tune
 from ray.rllib.algorithms import sac
 from ray.rllib.connectors.env_to_module.flatten_observations import FlattenObservations
 from ray.rllib.examples.envs.classes.random_env import RandomEnv
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.spaces.simplex import Simplex
 from ray.rllib.utils.test_utils import check_train_results_new_api_stack
-from ray import tune
 
 tf1, tf, tfv = try_import_tf()
 torch, _ = try_import_torch()
@@ -179,7 +180,8 @@ def step(self, action):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/sac/torch/default_sac_torch_rl_module.py b/rllib/algorithms/sac/torch/default_sac_torch_rl_module.py
index 0612dce7c391..09a0f6091ab1 100644
--- a/rllib/algorithms/sac/torch/default_sac_torch_rl_module.py
+++ b/rllib/algorithms/sac/torch/default_sac_torch_rl_module.py
@@ -1,17 +1,18 @@
-import gymnasium as gym
 from typing import Any, Dict
 
+import gymnasium as gym
+
 from ray.rllib.algorithms.sac.default_sac_rl_module import DefaultSACRLModule
 from ray.rllib.algorithms.sac.sac_catalog import SACCatalog
 from ray.rllib.algorithms.sac.sac_learner import (
     ACTION_DIST_INPUTS_NEXT,
-    QF_PREDS,
-    QF_TWIN_PREDS,
-    QF_TARGET_NEXT,
+    ACTION_LOG_PROBS,
     ACTION_LOG_PROBS_NEXT,
-    ACTION_PROBS_NEXT,
     ACTION_PROBS,
-    ACTION_LOG_PROBS,
+    ACTION_PROBS_NEXT,
+    QF_PREDS,
+    QF_TARGET_NEXT,
+    QF_TWIN_PREDS,
 )
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.models.base import ENCODER_OUT, Encoder, Model
diff --git a/rllib/algorithms/sac/torch/sac_torch_learner.py b/rllib/algorithms/sac/torch/sac_torch_learner.py
index 8d96f22f4730..478970795d85 100644
--- a/rllib/algorithms/sac/torch/sac_torch_learner.py
+++ b/rllib/algorithms/sac/torch/sac_torch_learner.py
@@ -1,24 +1,25 @@
-import gymnasium as gym
 from typing import Any, Dict
 
+import gymnasium as gym
+
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.algorithms.dqn.torch.dqn_torch_learner import DQNTorchLearner
 from ray.rllib.algorithms.sac.sac import SACConfig
 from ray.rllib.algorithms.sac.sac_learner import (
+    ACTION_LOG_PROBS,
+    ACTION_LOG_PROBS_NEXT,
+    ACTION_PROBS,
+    ACTION_PROBS_NEXT,
     LOGPS_KEY,
     QF_LOSS_KEY,
-    QF_MEAN_KEY,
     QF_MAX_KEY,
+    QF_MEAN_KEY,
     QF_MIN_KEY,
     QF_PREDS,
+    QF_TARGET_NEXT,
     QF_TWIN_LOSS_KEY,
     QF_TWIN_PREDS,
     TD_ERROR_MEAN_KEY,
-    ACTION_LOG_PROBS,
-    ACTION_LOG_PROBS_NEXT,
-    ACTION_PROBS,
-    ACTION_PROBS_NEXT,
-    QF_TARGET_NEXT,
     SACLearner,
 )
 from ray.rllib.core.columns import Columns
diff --git a/rllib/algorithms/tests/test_algorithm.py b/rllib/algorithms/tests/test_algorithm.py
index 39f583f5f722..e31ff3999271 100644
--- a/rllib/algorithms/tests/test_algorithm.py
+++ b/rllib/algorithms/tests/test_algorithm.py
@@ -1,15 +1,16 @@
-import gymnasium as gym
-import numpy as np
 import os
+import unittest
 from pathlib import Path
 from random import choice
-import unittest
+
+import gymnasium as gym
+import numpy as np
 
 import ray
-from ray.rllib.algorithms.algorithm import Algorithm
 import ray.rllib.algorithms.dqn as dqn
-from ray.rllib.algorithms.bc import BCConfig
 import ray.rllib.algorithms.ppo as ppo
+from ray.rllib.algorithms.algorithm import Algorithm
+from ray.rllib.algorithms.bc import BCConfig
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
 from ray.rllib.core.rl_module.rl_module import RLModuleSpec
@@ -615,7 +616,8 @@ def _assert_modules_added(
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/tests/test_algorithm_config.py b/rllib/algorithms/tests/test_algorithm_config.py
index afe48b1117c5..6b6c381f6e56 100644
--- a/rllib/algorithms/tests/test_algorithm_config.py
+++ b/rllib/algorithms/tests/test_algorithm_config.py
@@ -1,17 +1,18 @@
-import gymnasium as gym
-from typing import Type
 import unittest
+from typing import Type
+
+import gymnasium as gym
 
 import ray
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.algorithms.ppo import PPO, PPOConfig
 from ray.rllib.algorithms.ppo.torch.ppo_torch_learner import PPOTorchLearner
 from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
-from ray.rllib.core.rl_module.rl_module import RLModuleSpec, RLModule
 from ray.rllib.core.rl_module.multi_rl_module import (
     MultiRLModule,
     MultiRLModuleSpec,
 )
+from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec
 from ray.rllib.utils.test_utils import check
 
 
@@ -432,7 +433,8 @@ def get_default_rl_module_spec(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/tests/test_algorithm_export_checkpoint.py b/rllib/algorithms/tests/test_algorithm_export_checkpoint.py
index 116b68399aee..e978dc961b55 100644
--- a/rllib/algorithms/tests/test_algorithm_export_checkpoint.py
+++ b/rllib/algorithms/tests/test_algorithm_export_checkpoint.py
@@ -1,8 +1,9 @@
-import numpy as np
 import os
 import shutil
 import unittest
 
+import numpy as np
+
 import ray
 import ray._common
 from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
@@ -96,7 +97,8 @@ def test_save_appo_multi_agent(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/tests/test_algorithm_imports.py b/rllib/algorithms/tests/test_algorithm_imports.py
index f528f082e19c..352dd41d9880 100644
--- a/rllib/algorithms/tests/test_algorithm_imports.py
+++ b/rllib/algorithms/tests/test_algorithm_imports.py
@@ -17,7 +17,8 @@ def test_algo_import(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/tests/test_algorithm_rl_module_restore.py b/rllib/algorithms/tests/test_algorithm_rl_module_restore.py
index bc13d04567f5..7cbb1ec39269 100644
--- a/rllib/algorithms/tests/test_algorithm_rl_module_restore.py
+++ b/rllib/algorithms/tests/test_algorithm_rl_module_restore.py
@@ -1,25 +1,25 @@
-import gymnasium as gym
-import numpy as np
 import shutil
 import tempfile
-import tree
 import unittest
 
+import gymnasium as gym
+import numpy as np
+import tree
+
 import ray
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
 from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
 from ray.rllib.core import DEFAULT_MODULE_ID
 from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
-from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.core.rl_module.multi_rl_module import (
-    MultiRLModuleSpec,
     MultiRLModule,
+    MultiRLModuleSpec,
 )
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
-from ray.rllib.utils.test_utils import check
 from ray.rllib.utils.numpy import convert_to_numpy
-
+from ray.rllib.utils.test_utils import check
 
 NUM_AGENTS = 2
 
@@ -329,7 +329,8 @@ def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_connectors.py b/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_connectors.py
index ec58a5376faa..3ede13e17215 100644
--- a/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_connectors.py
+++ b/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_connectors.py
@@ -1,7 +1,6 @@
 import tempfile
 import unittest
 
-
 import ray
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.algorithms.ppo import PPOConfig
@@ -10,7 +9,6 @@
 from ray.rllib.utils.filter import RunningStat
 from ray.rllib.utils.test_utils import check
 
-
 algorithms_and_configs = {
     "PPO": (PPOConfig().training(train_batch_size=2, minibatch_size=2))
 }
@@ -228,6 +226,7 @@ def _assert_running_stats_consistency(
 
 if __name__ == "__main__":
     import sys
+
     import pytest
 
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_learner.py b/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_learner.py
index 78840a3fe4be..1d0134cc2769 100644
--- a/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_learner.py
+++ b/rllib/algorithms/tests/test_algorithm_save_load_checkpoint_learner.py
@@ -7,7 +7,6 @@
 from ray.rllib.core import DEFAULT_MODULE_ID
 from ray.rllib.utils.metrics import LEARNER_RESULTS
 
-
 algorithms_and_configs = {
     "PPO": (PPOConfig().training(train_batch_size=2, minibatch_size=2))
 }
@@ -126,6 +125,7 @@ def test_save_and_restore(self):
 
 if __name__ == "__main__":
     import sys
+
     import pytest
 
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/tests/test_env_runner_failures.py b/rllib/algorithms/tests/test_env_runner_failures.py
index 72808d517522..abc3c7d1e3f9 100644
--- a/rllib/algorithms/tests/test_env_runner_failures.py
+++ b/rllib/algorithms/tests/test_env_runner_failures.py
@@ -1,14 +1,15 @@
+import time
+import unittest
 from collections import defaultdict
+
 import gymnasium as gym
 import numpy as np
-import time
-import unittest
 
 import ray
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.algorithms.impala import IMPALAConfig
-from ray.rllib.algorithms.sac.sac import SACConfig
 from ray.rllib.algorithms.ppo import PPOConfig
+from ray.rllib.algorithms.sac.sac import SACConfig
 from ray.rllib.connectors.env_to_module.flatten_observations import FlattenObservations
 from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
 from ray.rllib.env.multi_agent_env import make_multi_agent
diff --git a/rllib/algorithms/tests/test_node_failures.py b/rllib/algorithms/tests/test_node_failures.py
index 34e6560a8aae..7e1350024740 100644
--- a/rllib/algorithms/tests/test_node_failures.py
+++ b/rllib/algorithms/tests/test_node_failures.py
@@ -14,7 +14,6 @@
     MODULE_TRAIN_BATCH_SIZE_MEAN,
 )
 
-
 object_store_memory = 10**8
 num_nodes = 3
 
@@ -193,7 +192,8 @@ def _train(self, *, config, iters, min_reward, preempt_freq):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/algorithms/tests/test_registry.py b/rllib/algorithms/tests/test_registry.py
index 85e8029691ba..534f79327ede 100644
--- a/rllib/algorithms/tests/test_registry.py
+++ b/rllib/algorithms/tests/test_registry.py
@@ -1,11 +1,11 @@
 import unittest
 
 from ray.rllib.algorithms.registry import (
+    ALGORITHMS,
+    ALGORITHMS_CLASS_TO_NAME,
     POLICIES,
     get_policy_class,
     get_policy_class_name,
-    ALGORITHMS_CLASS_TO_NAME,
-    ALGORITHMS,
 )
 
 
@@ -31,7 +31,8 @@ def test_registered_algorithm_names(self):
 
 
 if __name__ == "__main__":
-    import pytest
     import sys
 
+    import pytest
+
     sys.exit(pytest.main(["-v", __file__]))
diff --git a/rllib/utils/from_config.py b/rllib/utils/from_config.py
index 30b7290999fc..3f80e785265b 100644
--- a/rllib/utils/from_config.py
+++ b/rllib/utils/from_config.py
@@ -4,14 +4,13 @@
 import re
 from copy import deepcopy
 from functools import partial
+from typing import TYPE_CHECKING, Optional
 
 import yaml
 
 from ray.rllib.utils import force_list, merge_dicts
 from ray.rllib.utils.annotations import DeveloperAPI
-from typing import Optional, TYPE_CHECKING
-
 
 if TYPE_CHECKING:
     from ray.rllib.utils.typing import FromConfigSpec
 
diff --git a/rllib/utils/replay_buffers/__init__.py b/rllib/utils/replay_buffers/__init__.py
index e929ab7d5988..c5f53f25e3e3 100644
--- a/rllib/utils/replay_buffers/__init__.py
+++ b/rllib/utils/replay_buffers/__init__.py
@@ -1,11 +1,11 @@
 from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
 from ray.rllib.utils.replay_buffers.fifo_replay_buffer import FifoReplayBuffer
-from ray.rllib.utils.replay_buffers.multi_agent_mixin_replay_buffer import (
-    MultiAgentMixInReplayBuffer,
-)
 from ray.rllib.utils.replay_buffers.multi_agent_episode_buffer import (
     MultiAgentEpisodeReplayBuffer,
 )
+from ray.rllib.utils.replay_buffers.multi_agent_mixin_replay_buffer import (
+    MultiAgentMixInReplayBuffer,
+)
 from ray.rllib.utils.replay_buffers.multi_agent_prioritized_episode_buffer import (
     MultiAgentPrioritizedEpisodeReplayBuffer,
 )
@@ -24,7 +24,8 @@
 )
 from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer, StorageUnit
 from ray.rllib.utils.replay_buffers.reservoir_replay_buffer import ReservoirReplayBuffer
-from ray.rllib.utils.replay_buffers import utils
+
+from ray.rllib.utils.replay_buffers import utils  # isort: skip
 
 __all__ = [
     "EpisodeReplayBuffer",
diff --git a/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py b/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py
index 9deb9e7f1387..1077c4300c22 100644
--- a/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py
+++ b/rllib/utils/replay_buffers/tests/test_segment_tree_replay_buffer_api.py
@@ -1,8 +1,9 @@
 import unittest
+
 import numpy as np
 
 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
-from ray.rllib.execution.segment_tree import SumSegmentTree, MinSegmentTree
+from ray.rllib.execution.segment_tree import MinSegmentTree, SumSegmentTree
 from ray.rllib.utils.replay_buffers import PrioritizedEpisodeReplayBuffer
 
 
diff --git a/rllib/utils/runners/runner.py b/rllib/utils/runners/runner.py
index 6d40319b63d7..fb3a8b61d278 100644
--- a/rllib/utils/runners/runner.py
+++ b/rllib/utils/runners/runner.py
@@ -1,6 +1,5 @@
 import abc
 import logging
-
 from typing import TYPE_CHECKING, Any, Union
 
 from ray.rllib.utils.actor_manager import FaultAwareApply