diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index a26174b863..4b1ac1710a 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -76,12 +76,12 @@ jobs: TYPE="${{ steps.test_type.outputs.type }}" if [ "$TYPE" = "all" ]; then echo "tests_run=true" >> $GITHUB_ENV - docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json + docker compose exec trinity-node-1 pytest tests -v -s --ctrf report.json elif [ "$TYPE" = "diff" ]; then if [ -s ../../../test_dirs.txt ]; then echo "tests_run=true" >> $GITHUB_ENV TEST_DIRS=$(cat ../../../test_dirs.txt | xargs) - docker compose exec trinity-node-1 pytest $TEST_DIRS -v -s --ignore=tests/data --ctrf report.json + docker compose exec trinity-node-1 pytest $TEST_DIRS -v -s --ctrf report.json else echo "No changed modules detected, skipping tests." echo "tests_run=false" >> $GITHUB_ENV @@ -90,7 +90,7 @@ jobs: MODULE="${{ steps.test_type.outputs.module }}" if [ -n "$MODULE" ]; then echo "tests_run=true" >> $GITHUB_ENV - docker compose exec trinity-node-1 pytest tests/$MODULE -v -s --ignore=tests/data --ctrf report.json + docker compose exec trinity-node-1 pytest tests/$MODULE -v -s --ctrf report.json else echo "No module specified, skipping tests." echo "tests_run=false" >> $GITHUB_ENV diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index a397500aee..0669fdcbd6 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -40,6 +40,14 @@ monitor: data_processor: # Preprocessing data settings ... + +service: + # Services to use + ... + +log: + # Ray actor logging + ... ``` Each of these sections will be explained in detail below. @@ -395,28 +403,36 @@ trainer: --- -## Data Processor Configuration +## Service Configuration -Configures preprocessing and data cleaning pipelines. 
+Configures services used by Trinity-RFT. Only supports the Data Juicer service for now. ```yaml -data_processor: - source_data_path: /PATH/TO/DATASET - load_kwargs: - split: 'train' - format: - prompt_key: 'question' - response_key: 'answer' - dj_config_path: 'tests/test_configs/active_iterator_test_dj_cfg.yaml' - clean_strategy: 'iterative' - db_url: 'postgresql://{username}@localhost:5432/{db_name}' +service: + data_juicer: + server_url: 'http://127.0.0.1:5005' + auto_start: true + port: 5005 +``` + +- `server_url`: The URL of the Data Juicer server. +- `auto_start`: Whether to automatically start the Data Juicer service. +- `port`: The port for the Data Juicer service when `auto_start` is true. + +--- + +## Log Configuration + +Ray actor logging configuration. + +```yaml +log: + level: INFO + group_by_node: False ``` -- `source_data_path`: Path to the task dataset. -- `load_kwargs`: Arguments passed to HuggingFace’s `load_dataset()`. -- `dj_config_path`: Path to Data-Juicer configuration for cleaning. -- `clean_strategy`: Strategy for iterative data cleaning. -- `db_url`: Database URL if using SQL backend. +- `level`: The logging level (supports `DEBUG`, `INFO`, `WARNING`, `ERROR`). +- `group_by_node`: Whether to group logs by node IP. If set to `True`, an actor's logs will be saved to `///log//.log`, otherwise they will be saved to `///log/.log`. 
--- diff --git a/tests/tools.py b/tests/tools.py index 58108faa77..ccadb80ac5 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -13,7 +13,10 @@ def get_template_config() -> Config: config_path = os.path.join(os.path.dirname(__file__), "template", "config.yaml") config = load_config(config_path) - config.ray_namespace = ray.get_runtime_context().namespace + if ray.is_initialized(): + config.ray_namespace = ray.get_runtime_context().namespace + else: + config.ray_namespace = "trinity_unittest" return config diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 1b0207a494..190d3bf06a 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -21,7 +21,13 @@ ) from trinity.cli.launcher import bench, both, explore, train from trinity.common.config import Config, StorageConfig -from trinity.common.constants import StorageType, SyncMethod, SyncStyle +from trinity.common.constants import ( + LOG_DIR_ENV_VAR, + LOG_LEVEL_ENV_VAR, + StorageType, + SyncMethod, + SyncStyle, +) from trinity.common.models.utils import get_checkpoint_dir_with_step_num from trinity.manager.manager import CacheManager @@ -355,12 +361,28 @@ def tearDown(self): def run_trainer(config: Config) -> None: - ray.init(namespace=config.ray_namespace) + ray.init( + namespace=config.ray_namespace, + runtime_env={ + "env_vars": { + LOG_DIR_ENV_VAR: config.log.save_dir, + LOG_LEVEL_ENV_VAR: "INFO", + } + }, + ) train(config) def run_explorer(config: Config) -> None: - ray.init(namespace=config.ray_namespace) + ray.init( + namespace=config.ray_namespace, + runtime_env={ + "env_vars": { + LOG_DIR_ENV_VAR: config.log.save_dir, + LOG_LEVEL_ENV_VAR: "INFO", + } + }, + ) explore(config) @@ -487,6 +509,22 @@ def test_fully_async_mode(self, name, use_priority_queue): )[1], 8, ) + log_files = os.listdir(os.path.join(explorer1_config.checkpoint_job_dir, "log")) + self.assertTrue("trainer.log" in log_files) + self.assertTrue("synchronizer.log" in log_files) + 
self.assertTrue("explorer1.log" in log_files) + self.assertTrue("explorer2.log" in log_files) + self.assertTrue("explorer1_runner_0.log" in log_files) + self.assertTrue("explorer1_runner_7.log" in log_files) + self.assertTrue("explorer2_runner_0.log" in log_files) + self.assertTrue("explorer2_runner_7.log" in log_files) + self.assertTrue("explorer1_experience_pipeline.log" in log_files) + self.assertTrue("explorer2_experience_pipeline.log" in log_files) + files_to_check = ["trainer.log", "synchronizer.log", "explorer1.log", "explorer2.log"] + for file_name in files_to_check: + with open(os.path.join(explorer1_config.checkpoint_job_dir, "log", file_name)) as f: + lines = f.readlines() + self.assertTrue(len(lines) > 0) ray.shutdown() def tearDown(self): diff --git a/tests/utils/log_test.py b/tests/utils/log_test.py new file mode 100644 index 0000000000..2f1080ed4e --- /dev/null +++ b/tests/utils/log_test.py @@ -0,0 +1,166 @@ +import logging +import os +import shutil +import unittest + +import ray +from ray.runtime_env import RuntimeEnv + +from tests.tools import get_template_config +from trinity.common.constants import ( + LOG_DIR_ENV_VAR, + LOG_LEVEL_ENV_VAR, + LOG_NODE_IP_ENV_VAR, +) +from trinity.utils.log import get_logger + + +def log_outside_actor(log_level=logging.INFO): + logger = get_logger("outside_actor", level=log_level) + logger.info("Outside logger initialized") + logger.debug("Outside logger initialized") + + +class ModuleInActor: + def __init__(self): + self.logger = get_logger("module_in_actor", in_ray_actor=True) + self.logger.info("ModuleInActor initialized") + self.logger.debug("ModuleInActor initialized") + + +class ModuleInActor2: + def __init__(self): + # module create in actor should automatically inherit the logger created by the root actor + self.logger = get_logger("module_in_actor2") + self.logger.info("ModuleInActor2 initialized") + self.logger.debug("ModuleInActor2 initialized") + + +@ray.remote +class ActorInActor: + """An actor created 
inside an actor""" + + def __init__(self, parent_name, log_level): + self.logger = get_logger(f"{parent_name}_nested", in_ray_actor=True, level=log_level) + self.logger.info("ActorInActor initialized") + self.logger.debug("ActorInActor initialized") + + +@ray.remote +class LogActor: + def __init__(self, aid: int, log_level=logging.INFO): + assert os.environ.get(LOG_DIR_ENV_VAR) is not None, "LOG_DIR_ENV_VAR must be set" + self.logger = get_logger(f"actor_{aid}", in_ray_actor=True, level=log_level) + self.logger.info(f"LogActor {aid} initialized ") + self.logger.debug(f"LogActor {aid} initialized") + self.aid = aid + self.actor = ActorInActor.remote(f"actor_{aid}", log_level) + ray.get(self.actor.__ray_ready__.remote()) + + def log_info(self, message: str): + self.logger.info(f"LogActor {self.aid} info: {message}") + self.logger.debug(f"LogActor {self.aid} debug: {message}") + ModuleInActor() + ModuleInActor2() + + +class LogTest(unittest.TestCase): + def setUp(self): + if ray.is_initialized(): + ray.shutdown() + self.config = get_template_config() + self.config.check_and_update() + self.log_dir = self.config.log.save_dir + shutil.rmtree(self.log_dir, ignore_errors=True) + os.makedirs(self.log_dir, exist_ok=True) + + def test_no_actor_log(self): + ray.init( + namespace=self.config.ray_namespace, + runtime_env=RuntimeEnv( + env_vars={LOG_DIR_ENV_VAR: self.log_dir, LOG_LEVEL_ENV_VAR: "INFO"} + ), + ) + try: + logger = get_logger("outside_actor", level=logging.DEBUG) + logger.info("Outside logger initialized") + logger.debug("Outside logger initialized") + self.assertFalse(os.path.exists(os.path.join(self.log_dir, "outside_actor.log"))) + + logger = get_logger( + "outside_actor", in_ray_actor=True + ) # in_ray_actor should not take effect + logger.info("Outside logger initialized") + self.assertFalse(os.path.exists(os.path.join(self.log_dir, "outside_actor.log"))) + + finally: + ray.shutdown(_exiting_interpreter=True) + + def test_actor_log(self): + ray.init( + 
namespace=self.config.ray_namespace, + runtime_env=RuntimeEnv( + env_vars={ + LOG_DIR_ENV_VAR: self.log_dir, + LOG_LEVEL_ENV_VAR: "INFO", + } + ), + ) + try: + actor1 = LogActor.remote(1, log_level=logging.INFO) + actor2 = LogActor.remote(2, log_level=logging.DEBUG) + actor3 = LogActor.remote(3, log_level=None) + ray.get(actor1.log_info.remote("Test message")) + ray.get(actor2.log_info.remote("Test message")) + ray.get(actor3.log_info.remote("Test message")) + self.assertTrue(os.path.exists(os.path.join(self.log_dir, "actor_1.log"))) + self.assertTrue(os.path.exists(os.path.join(self.log_dir, "actor_2.log"))) + self.assertTrue(os.path.exists(os.path.join(self.log_dir, "actor_3.log"))) + self.assertTrue(os.path.exists(os.path.join(self.log_dir, "actor_1_nested.log"))) + self.assertTrue(os.path.exists(os.path.join(self.log_dir, "actor_2_nested.log"))) + self.assertTrue(os.path.exists(os.path.join(self.log_dir, "actor_3_nested.log"))) + self.assertFalse(os.path.exists(os.path.join(self.log_dir, "module_in_actor.log"))) + self.assertFalse(os.path.exists(os.path.join(self.log_dir, "module_in_actor2.log"))) + with open(os.path.join(self.log_dir, "actor_1.log"), "r") as f: + lines = f.readlines() + self.assertEqual(len(lines), 4) + with open(os.path.join(self.log_dir, "actor_2.log"), "r") as f: + lines = f.readlines() + self.assertEqual(len(lines), 8) + with open(os.path.join(self.log_dir, "actor_3.log"), "r") as f: + lines = f.readlines() + self.assertEqual(len(lines), 4) + with open(os.path.join(self.log_dir, "actor_1_nested.log"), "r") as f: + lines = f.readlines() + self.assertEqual(len(lines), 1) + with open(os.path.join(self.log_dir, "actor_2_nested.log"), "r") as f: + lines = f.readlines() + self.assertEqual(len(lines), 2) + with open(os.path.join(self.log_dir, "actor_3_nested.log"), "r") as f: + lines = f.readlines() + self.assertEqual(len(lines), 1) + finally: + ray.shutdown(_exiting_interpreter=True) + + def test_group_by_node(self): + ray.init( + 
namespace=self.config.ray_namespace, + runtime_env=RuntimeEnv( + env_vars={ + LOG_DIR_ENV_VAR: self.log_dir, + LOG_LEVEL_ENV_VAR: "INFO", + LOG_NODE_IP_ENV_VAR: "1", + } + ), + ) + try: + actor = LogActor.remote(1, log_level=logging.INFO) + ray.get(actor.log_info.remote("Test message")) + ips = os.listdir(self.config.log.save_dir) + self.assertTrue(len(ips) > 0) + for ip in ips: + self.assertTrue(os.path.isdir(os.path.join(self.config.log.save_dir, ip))) + ip_logs = os.listdir(os.path.join(self.config.log.save_dir, ip)) + self.assertTrue(len(ip_logs) > 0) + finally: + ray.shutdown(_exiting_interpreter=True) diff --git a/trinity/buffer/pipelines/experience_pipeline.py b/trinity/buffer/pipelines/experience_pipeline.py index d301a1cbf0..6250b46703 100644 --- a/trinity/buffer/pipelines/experience_pipeline.py +++ b/trinity/buffer/pipelines/experience_pipeline.py @@ -34,7 +34,7 @@ class ExperiencePipeline: """ def __init__(self, config: Config): - self.logger = get_logger(__name__) + self.logger = get_logger(f"{config.explorer.name}_experience_pipeline", in_ray_actor=True) load_plugins() pipeline_config = config.data_processor.experience_pipeline buffer_config = config.buffer diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py index 01271dad19..06e0367866 100644 --- a/trinity/buffer/ray_wrapper.py +++ b/trinity/buffer/ray_wrapper.py @@ -35,7 +35,7 @@ class DBWrapper: """ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: - self.logger = get_logger(__name__) + self.logger = get_logger(f"sql_{storage_config.name}") if storage_config.path is None: storage_config.path = default_storage_path(storage_config, config) self.engine = create_engine(storage_config.path, poolclass=NullPool) @@ -220,7 +220,7 @@ class QueueWrapper: """An wrapper of a async queue.""" def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None: - self.logger = get_logger(__name__) + self.logger = 
get_logger(f"queue_{storage_config.name}") self.config = config self.capacity = storage_config.capacity self.queue = QueueBuffer.get_queue(storage_config, config) diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index 3745730f22..53c9fa758c 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -8,9 +8,6 @@ from trinity.buffer.ray_wrapper import QueueWrapper from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import ReadStrategy, StorageType -from trinity.utils.log import get_logger - -logger = get_logger(__name__) class QueueReader(BufferReader): diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py index 4b8034d716..27dc6df28a 100644 --- a/trinity/buffer/writer/queue_writer.py +++ b/trinity/buffer/writer/queue_writer.py @@ -7,9 +7,6 @@ from trinity.buffer.ray_wrapper import QueueWrapper from trinity.common.config import BufferConfig, StorageConfig from trinity.common.constants import StorageType -from trinity.utils.log import get_logger - -logger = get_logger(__name__) class QueueWriter(BufferWriter): diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index 89c1456728..de7719de40 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -10,7 +10,12 @@ from trinity.buffer.pipelines.task_pipeline import check_and_run_task_pipeline from trinity.common.config import Config, load_config -from trinity.common.constants import PLUGIN_DIRS_ENV_VAR +from trinity.common.constants import ( + LOG_DIR_ENV_VAR, + LOG_LEVEL_ENV_VAR, + LOG_NODE_IP_ENV_VAR, + PLUGIN_DIRS_ENV_VAR, +) from trinity.explorer.explorer import Explorer from trinity.trainer.trainer import Trainer from trinity.utils.log import get_logger @@ -117,15 +122,17 @@ def both(config: Config) -> None: logger.error(f"Explorer or Trainer failed:\n{traceback.format_exc()}") -def run(config_path: str, dlc: bool = False, plugin_dir: str 
= None): +def run(config_path: str, log_level: str = "INFO", dlc: bool = False, plugin_dir: str = None): config = load_config(config_path) config.check_and_update() pprint(config) - # try to run task pipeline for raw data - check_and_run_task_pipeline(config) - - envs = {PLUGIN_DIRS_ENV_VAR: plugin_dir or ""} + envs = { + PLUGIN_DIRS_ENV_VAR: plugin_dir or "", + LOG_DIR_ENV_VAR: config.log.save_dir, + LOG_LEVEL_ENV_VAR: config.log.level, + LOG_NODE_IP_ENV_VAR: "1" if config.log.group_by_node else "0", + } if dlc: from trinity.utils.dlc_utils import setup_ray_cluster @@ -138,6 +145,10 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): ray.init( namespace=config.ray_namespace, ignore_reinit_error=True, runtime_env={"env_vars": envs} ) + + # try to run task pipeline for raw data + check_and_run_task_pipeline(config) + try: if config.mode == "explore": explore(config) @@ -205,7 +216,7 @@ def main() -> None: args = parser.parse_args() if args.command == "run": # TODO: support parse all args from command line - run(args.config, args.dlc, args.plugin_dir) + run(args.config, args.log_level, args.dlc, args.plugin_dir) elif args.command == "studio": studio(args.port) diff --git a/trinity/common/config.py b/trinity/common/config.py index 122e0675d2..7502faa48c 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -478,6 +478,16 @@ class ServiceConfig: data_juicer: Optional[DataJuicerServiceConfig] = None +@dataclass +class LogConfig: + """Configs for logger.""" + + level: str = "INFO" # default log level (DEBUG, INFO, WARNING, ERROR) + group_by_node: bool = False # whether to group logs by node IP in Ray cluster + # ! 
DO NOT SET, automatically generated as ///log + save_dir: str = "" + + @dataclass class Config: """Global Configuration""" @@ -505,6 +515,7 @@ class Config: monitor: MonitorConfig = field(default_factory=MonitorConfig) synchronizer: SynchronizerConfig = field(default_factory=SynchronizerConfig) service: ServiceConfig = field(default_factory=ServiceConfig) + log: LogConfig = field(default_factory=LogConfig) def save(self, config_path: str) -> None: """Save config to file.""" @@ -894,6 +905,9 @@ def check_and_update(self) -> None: # noqa: C901 if operator.name == "data_juicer": operator.args["service_config"] = self.service.data_juicer + # check log + self.log.save_dir = os.path.join(self.checkpoint_job_dir, "log") + def flatten(self) -> Dict[str, Any]: """Flatten the config into a single-level dict with dot-separated keys for nested fields.""" diff --git a/trinity/common/constants.py b/trinity/common/constants.py index 3faacdc3f1..ccb75109f7 100644 --- a/trinity/common/constants.py +++ b/trinity/common/constants.py @@ -3,9 +3,6 @@ from enum import Enum, EnumMeta from trinity.utils.annotations import Deprecated -from trinity.utils.log import get_logger - -logger = get_logger(__name__) # names @@ -14,7 +11,11 @@ ROLLOUT_WEIGHT_SYNC_GROUP_NAME = "rollout_weight_sync" +# trinity env var names PLUGIN_DIRS_ENV_VAR = "TRINITY_PLUGIN_DIRS" +LOG_DIR_ENV_VAR = "TRINITY_LOG_DIR" # log dir +LOG_LEVEL_ENV_VAR = "TRINITY_LOG_LEVEL" # global log level +LOG_NODE_IP_ENV_VAR = "TRINITY_LOG_NODE_IP" # whether to organize logs by node IP # constants @@ -87,15 +88,12 @@ class MonitorType(CaseInsensitiveEnum): class SyncMethodEnumMeta(CaseInsensitiveEnumMeta): def __call__(cls, value, *args, **kwargs): if value == "online": - logger.warning("SyncMethod `online` is deprecated, use `nccl` instead.") value = "nccl" elif value == "offline": - logger.warning("SyncMethod `offline` is deprecated, use `checkpoint` instead.") value = "checkpoint" try: return super().__call__(value, *args, 
**kwargs) - except Exception as e: - logger.warning("Error parsing SyncMethod:", e) + except Exception: raise ValueError(f"Invalid SyncMethod: {value}") diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index fb85590b06..964e90684c 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -10,7 +10,7 @@ class _BundleAllocator: """An allocator for bundles.""" def __init__(self, node_bundle_map: dict[str, list]) -> None: - self.logger = get_logger(__name__) + self.logger = get_logger(__name__, in_ray_actor=True) self.node_bundle_list = [value for value in node_bundle_map.values()] self.node_list = [key for key in node_bundle_map.keys()] self.nid = 0 diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index 67e5b59504..49760739eb 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -9,8 +9,6 @@ from trinity.utils.log import get_logger -logger = get_logger(__name__) - def tokenize_and_mask_messages_hf( tokenizer: Any, @@ -187,6 +185,7 @@ def get_verl_checkpoint_info( # copy from verl/scripts/model_merger.py def load_state_dict_from_verl_checkpoint(checkpoint_path: str) -> dict: # noqa: C901 """Load state dict from a Verl checkpoint.""" + logger = get_logger(__name__) logger.info(f"Loading state dict from {checkpoint_path}") assert not checkpoint_path.endswith( "huggingface" diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 7e55c6c85a..a116585412 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -20,8 +20,6 @@ ) from trinity.utils.log import get_logger -logger = get_logger(__name__) - # TODO: remove V0 when V1 is stable class vLLMRolloutModel(InferenceModel): @@ -242,7 +240,7 @@ def shutdown(self): and they won't be able to be tracked by Ray anymore. 
""" if hasattr(self.async_llm, "shutdown"): - logger.info("Shutting down vLLM engine") + self.logger.info("Shutting down vLLM engine") self.async_llm.shutdown() def _create_sampling_params(self, **kwargs): diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 57bc90d102..a312b38569 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -8,8 +8,6 @@ from trinity.utils.distributed import init_process_group from trinity.utils.log import get_logger -logger = get_logger(__name__) - class WorkerExtension: def init_process_group( @@ -26,14 +24,17 @@ def init_process_group( namespace: str = None, ): """Init torch process group for model weights update""" + rank = torch.distributed.get_rank() + self.logger = get_logger(f"vllm_worker_{rank}") + assert torch.distributed.is_initialized(), "default torch process group must be initialized" assert group_name != "", "group name must not be empty" self._state_dict_meta = state_dict_meta - self._weight_update_rank = torch.distributed.get_rank() + rank_offset - logger.info( + self._weight_update_rank = rank + rank_offset + self.logger.info( f"vLLM starting init_process_group:\n" f" > address={master_address}:{master_port}\n" - f" > rank={torch.distributed.get_rank()}\n" + f" > rank={rank}\n" f" > rank_offset={rank_offset}\n" f" > world_size={world_size}" ) @@ -48,7 +49,7 @@ def init_process_group( device_id=self.device, ) torch.distributed.barrier(group=self._model_update_group) - logger.info("vLLM init_process_group finished.") + self.logger.info("vLLM init_process_group finished.") self._explorer_name = explorer_name self._namespace = namespace self.synchronizer = Synchronizer.get_actor(namespace=self._namespace) diff --git a/trinity/common/rewards/countdown_reward.py b/trinity/common/rewards/countdown_reward.py index 07417431ad..a267b629f6 100644 --- a/trinity/common/rewards/countdown_reward.py +++ b/trinity/common/rewards/countdown_reward.py @@ -8,9 
+8,6 @@ extract_solution, validate_equation, ) -from trinity.utils.log import get_logger - -logger = get_logger(__name__) @REWARD_FUNCTIONS.register_module("countdown_reward") diff --git a/trinity/common/rewards/dapo_reward.py b/trinity/common/rewards/dapo_reward.py index a527bf613a..a6b6023aa2 100644 --- a/trinity/common/rewards/dapo_reward.py +++ b/trinity/common/rewards/dapo_reward.py @@ -6,9 +6,6 @@ from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn from trinity.utils.eval_utils import compute_score -from trinity.utils.log import get_logger - -logger = get_logger(__name__) @REWARD_FUNCTIONS.register_module("math_dapo_reward") diff --git a/trinity/common/rewards/format_reward.py b/trinity/common/rewards/format_reward.py index 0ffe8637ec..ad9b203f82 100644 --- a/trinity/common/rewards/format_reward.py +++ b/trinity/common/rewards/format_reward.py @@ -4,9 +4,6 @@ from typing import Optional from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn -from trinity.utils.log import get_logger - -logger = get_logger(__name__) @REWARD_FUNCTIONS.register_module("format_reward") diff --git a/trinity/common/rewards/math_reward.py b/trinity/common/rewards/math_reward.py index 9e1c8abed6..4772a54263 100644 --- a/trinity/common/rewards/math_reward.py +++ b/trinity/common/rewards/math_reward.py @@ -10,9 +10,6 @@ simple_answer_parser, validate_think_pattern, ) -from trinity.utils.log import get_logger - -logger = get_logger(__name__) @REWARD_FUNCTIONS.register_module("math_reward") diff --git a/trinity/common/rewards/reward_fn.py b/trinity/common/rewards/reward_fn.py index f2c58b70a5..bcf9f97f8a 100644 --- a/trinity/common/rewards/reward_fn.py +++ b/trinity/common/rewards/reward_fn.py @@ -5,12 +5,8 @@ from trinity.common.experience import Experience from trinity.common.rewards.utils import to_rm_gallery_messages -from trinity.utils.log import get_logger from trinity.utils.registry import Registry -logger = get_logger(__name__) - - 
REWARD_FUNCTIONS = Registry("reward_functions") diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py index aee3effd11..4fdfbf2a21 100644 --- a/trinity/common/workflows/customized_math_workflows.py +++ b/trinity/common/workflows/customized_math_workflows.py @@ -6,9 +6,6 @@ from trinity.common.experience import Experience from trinity.common.rewards.math_reward import MathBoxedRewardFn from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task -from trinity.utils.log import get_logger - -logger = get_logger(__name__) @WORKFLOWS.register_module("math_boxed_workflow") @@ -63,7 +60,7 @@ def run(self) -> List[Experience]: else: prompt_text = self.format_prompt() - logger.debug("start chat") + self.logger.debug("start chat") if not self.use_base: responses = self.model.chat(messages, **self.rollout_args) else: @@ -86,11 +83,11 @@ def run(self) -> List[Experience]: response.eid.run = i + self.run_id_base if not self.use_base: - logger.debug( + self.logger.debug( f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" ) else: - logger.debug( + self.logger.debug( f"self.task_desc: {self.task_desc}, prompt_text: {prompt_text}, response: {response.response_text}, reward: {reward}" ) return responses diff --git a/trinity/common/workflows/customized_toolcall_workflows.py b/trinity/common/workflows/customized_toolcall_workflows.py index bbb22cc893..ed5dc98771 100644 --- a/trinity/common/workflows/customized_toolcall_workflows.py +++ b/trinity/common/workflows/customized_toolcall_workflows.py @@ -12,9 +12,6 @@ from trinity.common.experience import Experience from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task -from trinity.utils.log import get_logger - -logger = get_logger(__name__) # Adapted from https://github.com/NVlabs/Tool-N1 qwen_tool_prompts = """# Tool @@ -238,7 +235,7 @@ def format_prompt(self): def 
run(self) -> List[Experience]: messages = self.format_prompt() - logger.debug("start chat") + self.logger.debug("start chat") responses = self.model.chat(messages, **self.rollout_args) for i, response in enumerate(responses): @@ -252,12 +249,12 @@ def run(self) -> List[Experience]: ground_truth=ground_truth, ) else: - logger.error( + self.logger.error( "Key 'answer' not found in self.raw_task. Assigning default reward." ) else: - logger.error("self.raw_task is None. Assigning default reward.") - logger.debug( + self.logger.error("self.raw_task is None. Assigning default reward.") + self.logger.debug( f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" ) response.reward = reward diff --git a/trinity/common/workflows/envs/agentscope/agentscope_react_workflow.py b/trinity/common/workflows/envs/agentscope/agentscope_react_workflow.py index 879313e664..535402f2a9 100644 --- a/trinity/common/workflows/envs/agentscope/agentscope_react_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscope_react_workflow.py @@ -8,9 +8,6 @@ from trinity.common.models.model import ModelWrapper from trinity.common.rewards.math_reward import MathBoxedRewardFn from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow -from trinity.utils.log import get_logger - -logger = get_logger(__name__) @WORKFLOWS.register_module("agentscope_reactv2_math_workflow") @@ -32,7 +29,7 @@ def __init__( from agentscope.service import ServiceToolkit, execute_python_code except ImportError as e: error_message = f"AgentScope is not installed. Please install the agentscope framework first before running the workflow. Error: {str(e)}" - logger.error(error_message) + self.logger.error(error_message) raise ImportError(error_message) # get openai client from model @@ -84,7 +81,7 @@ def reset(self, task: Task): from agentscope.agents import ReActAgentV2 except ImportError as e: error_message = f"AgentScope is not installed. 
Please install the agentscope framework first before running the workflow. Error: {str(e)}" - logger.error(error_message) + self.logger.error(error_message) raise ImportError(error_message) self.agent = ReActAgentV2( name="math_react_agent", @@ -109,7 +106,7 @@ def reset(self, task: Task): else: self.answer = str(self.truth) except Exception as e: - logger.debug(f"Error in getting answer from truth: {str(e)}") + self.logger.debug(f"Error in getting answer from truth: {str(e)}") self.answer = str(self.truth) # we use the boxed format to evaluate the answer @@ -125,7 +122,7 @@ def run(self): from agentscope.message import Msg except ImportError as e: error_message = f"AgentScope is not installed. Please install the agentscope framework first before running the workflow. Error: {str(e)}" - logger.error(error_message) + self.logger.error(error_message) raise ImportError(error_message) # provide the task to the react agent @@ -141,14 +138,14 @@ def run(self): response_text = content except Exception as e: error_message = f"Error in processing the response: {e}" - logger.info(error_message) + self.logger.info(error_message) response_text = str(content) reward = self.reward_fn(response_text, self.answer) reward = sum(reward.values()) - logger.debug(f"Reward: {reward}") + self.logger.debug(f"Reward: {reward}") experiences = self.model.extract_experience_from_history(clear_history=True) - logger.debug(f"Experiences extracted len: {len(experiences)}") + self.logger.debug(f"Experiences extracted len: {len(experiences)}") for i, experience in enumerate(experiences): experience.eid.step = i experience.reward = reward @@ -156,7 +153,7 @@ def run(self): if experience.metrics is None: experience.metrics = {} experience.metrics.update(turns_metrics) - logger.debug( + self.logger.debug( f"return experience len: {len(experiences)}, run_id: {str(experiences[-1].eid.run)}, final step reward: {experiences[-1].reward}" ) return experiences diff --git 
a/trinity/common/workflows/eval_workflow.py b/trinity/common/workflows/eval_workflow.py index 0341072de9..2d97a5325b 100644 --- a/trinity/common/workflows/eval_workflow.py +++ b/trinity/common/workflows/eval_workflow.py @@ -10,11 +10,8 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow -from trinity.utils.log import get_logger from trinity.utils.math_eval_utils import verify_math_answer -logger = get_logger(__name__) - @WORKFLOWS.register_module("math_eval_workflow") class MathEvalWorkflow(Workflow): diff --git a/trinity/common/workflows/math_rm_workflow.py b/trinity/common/workflows/math_rm_workflow.py index 3e4f9a9e13..b46f5b7713 100644 --- a/trinity/common/workflows/math_rm_workflow.py +++ b/trinity/common/workflows/math_rm_workflow.py @@ -8,9 +8,6 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task -from trinity.utils.log import get_logger - -logger = get_logger(__name__) @WORKFLOWS.register_module("math_rm_workflow") @@ -34,7 +31,7 @@ def __init__( def run(self) -> List[Experience]: messages = self.format_messages() - logger.debug("start chat") + self.logger.debug("start chat") responses = self.model.chat(messages, **self.rollout_args) for i, response in enumerate(responses): reward_dict = self.reward_fn( # type: ignore @@ -50,7 +47,7 @@ def run(self) -> List[Experience]: response.reward = reward response.eid.run = i + self.run_id_base - logger.debug( + self.logger.debug( f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" ) return responses diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 20e03c9271..9a1da4ca49 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py 
@@ -17,9 +17,6 @@ from trinity.utils.log import get_logger from trinity.utils.registry import Registry -logger = get_logger(__name__) - - WORKFLOWS = Registry("workflows") @@ -95,6 +92,7 @@ def __init__( self.model = model self.auxiliary_models = auxiliary_models self.run_id_base = 0 + self.logger = get_logger(__name__) @property def resettable(self): @@ -240,7 +238,7 @@ def run(self) -> List[Experience]: # TODO: Optimize the generate function messages = self.format_messages() - logger.debug("start chat") + self.logger.debug("start chat") responses = self.model.chat(messages, **self.rollout_args) for i, response in enumerate(responses): reward_dict = self.reward_fn( # type: ignore [misc] @@ -255,7 +253,7 @@ def run(self) -> List[Experience]: response.reward = reward response.eid.run = i + self.run_id_base - logger.debug( + self.logger.debug( f"self.task_desc: {self.task_desc}, messages: {messages}, response: {response.response_text}, reward: {reward}" ) return responses diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index c3c8c40cb6..b78856442e 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -36,7 +36,7 @@ class Explorer: """Responsible for exploring the taskset.""" def __init__(self, config: Config): - self.logger = get_logger(__name__) + self.logger = get_logger(config.explorer.name, in_ray_actor=True) load_plugins() self.cache = CacheManager(config) explorer_meta = self.cache.load_explorer() diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index ba810f041a..967ae93153 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -32,7 +32,7 @@ def __init__( auxiliary_models: Optional[List[InferenceModel]] = None, runner_id: Optional[int] = None, ) -> None: - self.logger = get_logger(__name__) + self.logger = get_logger(f"{config.explorer.name}_runner_{runner_id}", in_ray_actor=True) self.config = config self.model = model 
self.model_wrapper = ModelWrapper( diff --git a/trinity/manager/manager.py b/trinity/manager/manager.py index 4af6f28685..f65f5af204 100644 --- a/trinity/manager/manager.py +++ b/trinity/manager/manager.py @@ -6,13 +6,12 @@ from trinity.common.config import Config, load_config from trinity.utils.log import get_logger -logger = get_logger(__name__) - class CacheManager: """A Manager class for managing the cache dir.""" def __init__(self, config: Config, check_config: bool = False): + self.logger = get_logger(__name__, in_ray_actor=True) self.cache_dir = config.monitor.cache_dir # type: ignore self.explorer_meta_path = os.path.join(self.cache_dir, f"{config.explorer.name}_meta.json") # type: ignore self.trainer_meta_path = os.path.join(self.cache_dir, f"{config.trainer.name}_meta.json") # type: ignore @@ -27,7 +26,7 @@ def _check_config_consistency(self, config: Config) -> None: else: backup_config = load_config(backup_config_path) if backup_config != config: - logger.warning( + self.logger.warning( f"The current config is inconsistent with the backup config in {backup_config_path}." 
) raise ValueError( @@ -47,7 +46,7 @@ def load_explorer(self) -> dict: try: with open(self.explorer_meta_path, "r", encoding="utf-8") as f: explorer_meta = json.load(f) - logger.info( + self.logger.info( "----------------------------------\n" "Found existing explorer checkpoint:\n" f" > {explorer_meta}\n" @@ -56,7 +55,7 @@ def load_explorer(self) -> dict: ) return explorer_meta except Exception as e: - logger.error(f"Failed to load explore meta file: {e}") + self.logger.error(f"Failed to load explore meta file: {e}") return {} def save_trainer(self, current_step: int) -> None: @@ -68,7 +67,7 @@ def load_trainer(self) -> dict: try: with open(self.trainer_meta_path, "r", encoding="utf-8") as f: trainer_meta = json.load(f) - logger.info( + self.logger.info( "----------------------------------\n" "Found existing trainer checkpoint:\n" f" > {trainer_meta}\n" @@ -77,5 +76,5 @@ def load_trainer(self) -> dict: ) return trainer_meta except Exception as e: - logger.warning(f"Failed to load trainer meta file: {e}") + self.logger.warning(f"Failed to load trainer meta file: {e}") return {} diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py index 519b96083d..3bf16aa530 100644 --- a/trinity/manager/synchronizer.py +++ b/trinity/manager/synchronizer.py @@ -31,7 +31,7 @@ class Synchronizer: """ def __init__(self, config: Config, module_ref: ray.actor.ActorHandle): - self.logger = get_logger(__name__) + self.logger = get_logger("synchronizer", in_ray_actor=True) self.config = config self.trainer_status = RunningStatus.STOPPED self.explorer_status_counts: Dict[RunningStatus, int] = defaultdict(lambda: 0) diff --git a/trinity/service/data_juicer/client.py b/trinity/service/data_juicer/client.py index ee0ea91838..44238eab60 100644 --- a/trinity/service/data_juicer/client.py +++ b/trinity/service/data_juicer/client.py @@ -18,7 +18,7 @@ class DataJuicerClient: """Client for interacting with the DataJuicer server.""" def __init__(self, config: 
DataJuicerServiceConfig): - self.logger = get_logger(__name__) + self.logger = get_logger(__name__, in_ray_actor=True) self.config = config self.url = config.server_url self.session_id = None diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 2d6f123a19..9626818525 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -28,7 +28,7 @@ class Trainer: def __init__(self, config: Config) -> None: self.config = config - self.logger = get_logger(__name__) + self.logger = get_logger(config.trainer.name, in_ray_actor=True) load_plugins() self.synchronizer = Synchronizer.get_actor(config) self.engine = get_trainer_wrapper(config) diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 533417265e..113667dedc 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -144,7 +144,7 @@ def __init__( ray_worker_group_cls, ) self.init_workers() - self.logger = get_logger(__name__) + self.logger = get_logger(__name__, in_ray_actor=True) self.last_full_save_step = None def _validate_config(self): # TODO diff --git a/trinity/utils/log.py b/trinity/utils/log.py index d1d50d9ad8..0b3f2b4bd6 100644 --- a/trinity/utils/log.py +++ b/trinity/utils/log.py @@ -1,65 +1,100 @@ -# Adapted from -# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py -"""Logging configuration for vLLM.""" +"""A Ray compatible logging module with actor-scope logger support.""" +import contextvars import logging +import os import sys +from logging.handlers import RotatingFileHandler +from typing import Optional -_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" -_DATE_FORMAT = "%m-%d %H:%M:%S" +import ray +from trinity.common.constants import ( + LOG_DIR_ENV_VAR, + LOG_LEVEL_ENV_VAR, + LOG_NODE_IP_ENV_VAR, +) -class NewLineFormatter(logging.Formatter): - """Adds logging prefix to newlines to align multi-line messages.""" - - def 
__init__(self, fmt, datefmt=None): - logging.Formatter.__init__(self, fmt, datefmt) - - def format(self, record): - msg = logging.Formatter.format(self, record) - if record.message != "": - parts = msg.split(record.message) - msg = msg.replace("\n", "\r\n" + parts[0]) - return msg +_LOG_FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s" +_LOG_DATE_FORMAT = "%m-%d %H:%M:%S" -_root_logger = logging.getLogger("trinity") -_default_handler = None +class NewLineFormatter(logging.Formatter): + """ + Formatter that adds logging prefix to newlines to align multi-line messages. + """ + def __init__(self, fmt: str, datefmt: Optional[str] = None): + super().__init__(fmt, datefmt) -def _setup_logger(): - _root_logger.setLevel(logging.DEBUG) - global _default_handler - if _default_handler is None: - _default_handler = logging.StreamHandler(sys.stdout) - _default_handler.flush = sys.stdout.flush # type: ignore - _default_handler.setLevel(logging.INFO) - _root_logger.addHandler(_default_handler) - fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) - _default_handler.setFormatter(fmt) - # Setting this will avoid the message - # being propagated to the parent logger. - _root_logger.propagate = False + def format(self, record: logging.LogRecord) -> str: + msg = super().format(record) + if record.message: + prefix = msg.split(record.message)[0] + msg = msg.replace("\n", f"\r\n{prefix}") + return msg -# The logger is initialized when the module is imported. -# This is thread-safe as the module is only imported once, -# guaranteed by the Python GIL. -_setup_logger() +_ray_logger_ctx: contextvars.ContextVar[Optional[logging.Logger]] = contextvars.ContextVar( + "ray_logger", default=None +) -def get_logger(name: str, level: int = logging.DEBUG) -> logging.Logger: - """Get a logger with the given name and level. 
+def get_logger( + name: Optional[str] = None, level: Optional[int] = None, in_ray_actor: bool = False +) -> logging.Logger: + """ + Get a logger instance, compatible with Ray Actor and standard usage. Args: - name (str): The name of the logger. - level (int, optional): The level of the logger. Defaults to logging.DEBUG. + name (Optional[str]): The name of the logger. If None, uses 'trinity'. + level (Optional[int]): The logging level. If None, uses LOG_LEVEL_ENV_VAR or INFO. + in_ray_actor (bool): Whether the logger is used within a Ray actor. Returns: - logging.Logger: The logger with the given name and level. + logging.Logger: Configured logger instance. """ - logger = logging.getLogger(name) - logger.setLevel(level) - if not logger.handlers: - logger.addHandler(_default_handler) # type: ignore [arg-type] + # Reuse logger created by the actor if exists (Ray context) + logger = _ray_logger_ctx.get() + if logger is not None: + return logger + + resolved_level = ( + level + if level is not None + else getattr(logging, os.environ.get(LOG_LEVEL_ENV_VAR, "INFO").upper()) + ) + logger_name = f"trinity.{name}" if name else "trinity" + logger = logging.getLogger(logger_name) logger.propagate = False + logger.setLevel(resolved_level) + logger.handlers.clear() + + # Stream handler (stdout) + stream_handler = logging.StreamHandler(sys.stdout) + stream_handler.setLevel(resolved_level) + formatter = NewLineFormatter(_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT) + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + + if in_ray_actor: + # File handler (rotating file log) + log_dir = os.environ.get(LOG_DIR_ENV_VAR) + assert name is not None, "Logger name must be set when logging from a Ray actor" + if log_dir: + if os.environ.get(LOG_NODE_IP_ENV_VAR, "0") != "0": + # organize logs by node IP + node_ip = ray.util.get_node_ip_address() + log_dir = os.path.join(log_dir, node_ip) + os.makedirs(log_dir, exist_ok=True) + # save log into log_dir/{actor_name}.log + 
file_path = os.path.join(log_dir, f"{name}.log") + file_handler = RotatingFileHandler( + file_path, encoding="utf-8", maxBytes=64 * 1024 * 1024 + ) + file_handler.setLevel(resolved_level) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + _ray_logger_ctx.set(logger) + # If LOG_DIR_ENV_VAR is not set, file logging is disabled + return logger diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index 3a596d7b9f..5f9961c7b1 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -100,7 +100,7 @@ def __init__( self.tensorboard_dir = os.path.join(config.monitor.cache_dir, "tensorboard", role) os.makedirs(self.tensorboard_dir, exist_ok=True) self.logger = SummaryWriter(self.tensorboard_dir) - self.console_logger = get_logger(__name__) + self.console_logger = get_logger(__name__, in_ray_actor=True) def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int): pass @@ -143,7 +143,7 @@ def __init__( config=config, save_code=False, ) - self.console_logger = get_logger(__name__) + self.console_logger = get_logger(__name__, in_ray_actor=True) def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int): experiences_table = wandb.Table(dataframe=experiences_table) @@ -197,7 +197,7 @@ def __init__( }, ) mlflow.log_params(config.flatten()) - self.console_logger = get_logger(__name__) + self.console_logger = get_logger(__name__, in_ray_actor=True) def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int): pass diff --git a/trinity/utils/plugin_loader.py b/trinity/utils/plugin_loader.py index 203aa918c2..c3d956f2b1 100644 --- a/trinity/utils/plugin_loader.py +++ b/trinity/utils/plugin_loader.py @@ -10,8 +10,6 @@ from trinity.common.constants import PLUGIN_DIRS_ENV_VAR from trinity.utils.log import get_logger -logger = get_logger(__name__) - def load_plugins() -> None: """ @@ -29,6 +27,7 @@ def load_plugin_from_dirs(plugin_dirs: Union[str, List[str]]) -> None: """ Load 
plugin modules from a directory. """ + logger = get_logger(__name__, in_ray_actor=True) if not isinstance(plugin_dirs, list): plugin_dirs = [plugin_dirs] plugin_dirs = set(plugin_dirs) @@ -40,7 +39,6 @@ def load_plugin_from_dirs(plugin_dirs: Union[str, List[str]]) -> None: logger.error(f"plugin-dir [{plugin_dir}] is not a directory.") continue - logger.info(f"Loading plugin modules from [{plugin_dir}]...") for file in Path(plugin_dir).glob("*.py"): if file.name.startswith("__"): continue @@ -83,5 +81,4 @@ def load_from_file(file_path: str): shutil.copy2(file_path, Path(__file__).parent.parent / "plugins") except shutil.SameFileError: pass - logger.info(f"Load {file_path} as {full_module_name}") return module diff --git a/trinity/utils/registry.py b/trinity/utils/registry.py index 371d7a4346..e5f6806378 100644 --- a/trinity/utils/registry.py +++ b/trinity/utils/registry.py @@ -1,9 +1,5 @@ from typing import Any, Type -from trinity.utils.log import get_logger - -logger = get_logger(__name__) - # TODO: support lazy load # e.g. @MODULES.register_module("name", lazy=True) @@ -38,11 +34,6 @@ def modules(self) -> dict: """ return self._modules - def list(self) -> None: - """Logging the list of module in current registry.""" - for m in self._modules.keys(): - logger.info(f"{self._name}\t{m}") - def get(self, module_key) -> Any: """ Get module named module_key from in current registry. If not found, @@ -99,7 +90,6 @@ class MyWorkflow(Workflow): module_cls=MyWorkflow, force=True, ) - """ if not (module_name is None or isinstance(module_name, str)): raise TypeError(f"module_name must be either of None, str," f"got {type(module_name)}")