diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index 41f7e226ab..a26174b863 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -31,8 +31,17 @@ jobs:
       - name: Check ray status
         working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
         run: |
-          docker compose exec trinity-node-1 ray status
-          docker compose exec trinity-node-2 ray status
+          MAX_RETRIES=20
+          RETRY_INTERVAL=5
+          for i in $(seq 1 $MAX_RETRIES); do
+            docker compose exec trinity-node-1 ray status && docker compose exec trinity-node-2 ray status && break
+            echo "Waiting for ray cluster to be ready... ($i/$MAX_RETRIES)"
+            sleep $RETRY_INTERVAL
+            if [ "$i" -eq "$MAX_RETRIES" ]; then
+              echo "Ray cluster failed to start after $MAX_RETRIES retries."
+              exit 1
+            fi
+          done

       - name: Decide test type
         id: test_type
@@ -89,7 +98,7 @@
           fi

       - name: Upload test results
-        if: env.tests_run == 'true'
+        if: env.tests_run == 'true' || failure()
         uses: actions/upload-artifact@v4
         with:
           name: pytest-results
@@ -97,7 +106,7 @@
         continue-on-error: true

       - name: Publish Test Report
-        if: env.tests_run == 'true'
+        if: env.tests_run == 'true' || failure()
         uses: ctrf-io/github-test-reporter@v1
         with:
           report-path: trinity-${{ github.run_id }}/report.json
diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index 994b573c04..a7b4d63530 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -74,6 +74,14 @@ def test_continue_from_checkpoint_is_valid(self):
         timestamp = config.name.split("_")[-1]
         self.assertTrue(datetime.datetime.strptime(timestamp, "%Y%m%d%H%M%S"))

+    def test_config_flatten(self):
+        config = get_template_config()
+        flat_config = config.flatten()
+        self.assertIsInstance(flat_config, dict)
+        for key, value in flat_config.items():
+            self.assertIsInstance(key, str)
+            self.assertNotIsInstance(value, dict)
+
     def tearDown(self):
         if os.path.exists(CHECKPOINT_ROOT_DIR):
             shutil.rmtree(CHECKPOINT_ROOT_DIR)
diff --git a/trinity/common/config.py b/trinity/common/config.py
index 717a97b668..3b735bf4f3 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -3,6 +3,7 @@
 import os
 from dataclasses import dataclass, field
 from datetime import datetime
+from enum import Enum
 from typing import Any, Dict, List, Optional

 from omegaconf import OmegaConf
@@ -386,7 +387,7 @@ class MonitorConfig:
     # TODO: support multiple monitors (List[str])
     monitor_type: str = "tensorboard"
     # the default args for monitor
-    monitor_args: Dict = field(default_factory=dict)
+    monitor_args: Optional[Dict] = None
     # whether to enable ray timeline profile
     # the output file will be saved to `cache_dir/timeline.json`
     enable_ray_timeline: bool = False
@@ -793,6 +794,14 @@ def check_and_update(self) -> None:  # noqa: C901

         self._check_interval()

+        # check monitor
+        from trinity.utils.monitor import MONITOR
+
+        monitor_cls = MONITOR.get(self.monitor.monitor_type)
+        if monitor_cls is None:
+            raise ValueError(f"Invalid monitor type: {self.monitor.monitor_type}")
+        if self.monitor.monitor_args is None:
+            self.monitor.monitor_args = monitor_cls.default_args()
         # create a job dir in <checkpoint_root_dir>/<project>/<name>/monitor
         self.monitor.cache_dir = os.path.join(self.checkpoint_job_dir, "monitor")
         try:
@@ -831,6 +840,29 @@ def check_and_update(self) -> None:  # noqa: C901
         else:
             self.trainer.trainer_config = None

+    def flatten(self) -> Dict[str, Any]:
+        """Flatten the config into a single-level dict with dot-separated keys for nested fields."""
+
+        def _flatten(obj, parent_key="", sep="."):
+            items = {}
+            if hasattr(obj, "__dataclass_fields__"):
+                obj = vars(obj)
+            if isinstance(obj, dict):
+                for k, v in obj.items():
+                    new_key = f"{parent_key}{sep}{k}" if parent_key else k
+                    items.update(_flatten(v, new_key, sep=sep))
+            elif isinstance(obj, list):
+                for i, v in enumerate(obj):
+                    new_key = f"{parent_key}{sep}{i}" if parent_key else str(i)
+                    items.update(_flatten(v, new_key, sep=sep))
+            elif isinstance(obj, Enum):
+                items[parent_key] = obj.value
+            else:
+                items[parent_key] = obj
+            return items
+
+        return _flatten(self)
+

 def load_config(config_path: str) -> Config:
     """Load the configuration from the given path."""
diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py
index 6ba4d7482d..75860357eb 100644
--- a/trinity/utils/monitor.py
+++ b/trinity/utils/monitor.py
@@ -6,7 +6,16 @@

 import numpy as np
 import pandas as pd
-import wandb
+
+try:
+    import wandb
+except ImportError:
+    wandb = None
+
+try:
+    import mlflow
+except ImportError:
+    mlflow = None
 from torch.utils.tensorboard import SummaryWriter

 from trinity.common.config import Config
@@ -77,6 +86,11 @@ def calculate_metrics(
                 metrics[key] = val
         return metrics

+    @classmethod
+    def default_args(cls) -> Dict:
+        """Return default arguments for the monitor."""
+        return {}
+

 @MONITOR.register_module("tensorboard")
 class TensorboardMonitor(Monitor):
@@ -103,11 +117,24 @@ def close(self) -> None:

 @MONITOR.register_module("wandb")
 class WandbMonitor(Monitor):
+    """Monitor with Weights & Biases.
+
+    Args:
+        base_url (`Optional[str]`): The base URL of the W&B server. If not provided, the environment variable `WANDB_BASE_URL` is used.
+        api_key (`Optional[str]`): The API key for W&B. If not provided, the environment variable `WANDB_API_KEY` is used.
+    """
+
     def __init__(
         self, project: str, group: str, name: str, role: str, config: Config = None
     ) -> None:
+        assert wandb is not None, "wandb is not installed. Please install it to use WandbMonitor."
         if not group:
             group = name
+        monitor_args = config.monitor.monitor_args or {}
+        if base_url := monitor_args.get("base_url"):
+            os.environ["WANDB_BASE_URL"] = base_url
+        if api_key := monitor_args.get("api_key"):
+            os.environ["WANDB_API_KEY"] = api_key
         self.logger = wandb.init(
             project=project,
             group=group,
@@ -129,3 +156,65 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:

     def close(self) -> None:
         self.logger.finish()
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        """Return default arguments for the monitor."""
+        return {
+            "base_url": None,
+            "api_key": None,
+        }
+
+
+@MONITOR.register_module("mlflow")
+class MlflowMonitor(Monitor):
+    """Monitor with MLflow.
+
+    Args:
+        uri (`Optional[str]`): The tracking server URI. Defaults to `http://localhost:5000`.
+        username (`Optional[str]`): The username used to log in to the tracking server. Defaults to `None`.
+        password (`Optional[str]`): The password used to log in to the tracking server. Defaults to `None`.
+    """
+
+    def __init__(
+        self, project: str, group: str, name: str, role: str, config: Config = None
+    ) -> None:
+        assert (
+            mlflow is not None
+        ), "mlflow is not installed. Please install it to use MlflowMonitor."
+        monitor_args = config.monitor.monitor_args or {}
+        if username := monitor_args.get("username"):
+            os.environ["MLFLOW_TRACKING_USERNAME"] = username
+        if password := monitor_args.get("password"):
+            os.environ["MLFLOW_TRACKING_PASSWORD"] = password
+        mlflow.set_tracking_uri(monitor_args.get("uri", "http://localhost:5000"))
+        mlflow.set_experiment(project)
+        mlflow.start_run(
+            run_name=f"{name}_{role}",
+            tags={
+                "group": group,
+                "role": role,
+            },
+        )
+        mlflow.log_params(config.flatten())
+        self.console_logger = get_logger(__name__)
+
+    def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
+        pass
+
+    def log(self, data: dict, step: int, commit: bool = False) -> None:
+        """Log metrics."""
+        mlflow.log_metrics(metrics=data, step=step)
+        self.console_logger.info(f"Step {step}: {data}")
+
+    def close(self) -> None:
+        mlflow.end_run()
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        """Return default arguments for the monitor."""
+        return {
+            "uri": "http://localhost:5000",
+            "username": None,
+            "password": None,
+        }
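
Usage notes (not part of the patch). A minimal sketch of what the new `Config.flatten()` returns, assuming any valid Trinity-RFT config file; the file path and the example keys in the comments are illustrative, not a guaranteed schema:

    from trinity.common.config import load_config

    config = load_config("config.yaml")  # path is illustrative
    flat = config.flatten()

    # `flat` is a single-level dict with dot-separated keys such as "monitor.monitor_type".
    # Nested dataclasses become key prefixes, list items get an index segment ("foo.0.bar"),
    # and Enum fields are stored by value, so the result can be passed to mlflow.log_params().
    for key, value in flat.items():
        assert isinstance(key, str) and not isinstance(value, dict)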
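
A second sketch covers the new monitor_args defaulting: `check_and_update()` looks the monitor class up in the `MONITOR` registry and, when the YAML leaves `monitor.monitor_args` unset, copies in that class's `default_args()` (shown here for "mlflow"; the printed dict mirrors the values in the patch):

    from trinity.utils.monitor import MONITOR, MlflowMonitor

    monitor_cls = MONITOR.get("mlflow")  # registry lookup; None for unknown monitor types
    assert monitor_cls is MlflowMonitor
    print(monitor_cls.default_args())
    # {'uri': 'http://localhost:5000', 'username': None, 'password': None}
    # check_and_update() raises ValueError instead when monitor_type is not registered.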
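
Finally, an illustrative view of the `monitor_args` each backend reads at init time; every value below is a placeholder except the MLflow URI fallback:

    # These dicts mirror what goes under monitor.monitor_args in the YAML config.
    wandb_args = {
        "base_url": "https://wandb.example.com",  # exported as WANDB_BASE_URL if set
        "api_key": "<your-api-key>",              # exported as WANDB_API_KEY if set
    }
    mlflow_args = {
        "uri": "http://localhost:5000",           # passed to mlflow.set_tracking_uri()
        "username": "alice",                      # exported as MLFLOW_TRACKING_USERNAME if set
        "password": "secret",                     # exported as MLFLOW_TRACKING_PASSWORD if set
    }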