Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions .github/workflows/unittest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,17 @@ jobs:
- name: Check ray status
working-directory: trinity-${{ github.run_id }}/.github/workflows/docker
run: |
docker compose exec trinity-node-1 ray status
docker compose exec trinity-node-2 ray status
MAX_RETRIES=20
RETRY_INTERVAL=5
for i in $(seq 1 $MAX_RETRIES); do
docker compose exec trinity-node-1 ray status && docker compose exec trinity-node-2 ray status && break
echo "Waiting for ray cluster to be ready... ($i/$MAX_RETRIES)"
sleep $RETRY_INTERVAL
if [ "$i" -eq "$MAX_RETRIES" ]; then
echo "Ray cluster failed to start after $MAX_RETRIES retries."
exit 1
fi
done

- name: Decide test type
id: test_type
Expand Down Expand Up @@ -89,15 +98,15 @@ jobs:
fi

- name: Upload test results
if: env.tests_run == 'true'
if: env.tests_run == 'true' || failure()
uses: actions/upload-artifact@v4
with:
name: pytest-results
path: trinity-${{ github.run_id }}/report.json
continue-on-error: true

- name: Publish Test Report
if: env.tests_run == 'true'
if: env.tests_run == 'true' || failure()
uses: ctrf-io/github-test-reporter@v1
with:
report-path: trinity-${{ github.run_id }}/report.json
Expand Down
8 changes: 8 additions & 0 deletions tests/common/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ def test_continue_from_checkpoint_is_valid(self):
timestamp = config.name.split("_")[-1]
self.assertTrue(datetime.datetime.strptime(timestamp, "%Y%m%d%H%M%S"))

def test_config_flatten(self):
    """The flattened config is a flat mapping: str keys, no dict values."""
    flat = get_template_config().flatten()
    self.assertIsInstance(flat, dict)
    for key, value in flat.items():
        self.assertIsInstance(key, str)
        self.assertNotIsInstance(value, dict)

def tearDown(self):
if os.path.exists(CHECKPOINT_ROOT_DIR):
shutil.rmtree(CHECKPOINT_ROOT_DIR)
34 changes: 33 additions & 1 deletion trinity/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from omegaconf import OmegaConf
Expand Down Expand Up @@ -386,7 +387,7 @@ class MonitorConfig:
# TODO: support multiple monitors (List[str])
monitor_type: str = "tensorboard"
# the default args for monitor
monitor_args: Dict = field(default_factory=dict)
monitor_args: Optional[Dict] = None
# whether to enable ray timeline profile
# the output file will be saved to `cache_dir/timeline.json`
enable_ray_timeline: bool = False
Expand Down Expand Up @@ -793,6 +794,14 @@ def check_and_update(self) -> None: # noqa: C901

self._check_interval()

# check monitor
from trinity.utils.monitor import MONITOR

monitor_cls = MONITOR.get(self.monitor.monitor_type)
if monitor_cls is None:
raise ValueError(f"Invalid monitor type: {self.monitor.monitor_type}")
if self.monitor.monitor_args is None:
self.monitor.monitor_args = monitor_cls.default_args()
# create a job dir in <checkpoint_root_dir>/<project>/<name>/monitor
self.monitor.cache_dir = os.path.join(self.checkpoint_job_dir, "monitor")
try:
Expand Down Expand Up @@ -831,6 +840,29 @@ def check_and_update(self) -> None: # noqa: C901
else:
self.trainer.trainer_config = None

def flatten(self) -> Dict[str, Any]:
    """Flatten the config into a single-level dict with dot-separated keys for nested fields."""

    def _walk(node, prefix=""):
        # Dataclasses are treated as dicts of their instance attributes.
        if hasattr(node, "__dataclass_fields__"):
            node = vars(node)
        if isinstance(node, dict):
            pairs = node.items()
        elif isinstance(node, list):
            pairs = ((str(idx), item) for idx, item in enumerate(node))
        else:
            # Leaf value: unwrap Enum members to their underlying value.
            return {prefix: node.value if isinstance(node, Enum) else node}
        flat = {}
        for part, child in pairs:
            key = f"{prefix}.{part}" if prefix else part
            flat.update(_walk(child, key))
        return flat

    return _walk(self)


def load_config(config_path: str) -> Config:
"""Load the configuration from the given path."""
Expand Down
91 changes: 90 additions & 1 deletion trinity/utils/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@

import numpy as np
import pandas as pd
import wandb

try:
import wandb
except ImportError:
wandb = None

try:
import mlflow
except ImportError:
mlflow = None
from torch.utils.tensorboard import SummaryWriter

from trinity.common.config import Config
Expand Down Expand Up @@ -77,6 +86,11 @@ def calculate_metrics(
metrics[key] = val
return metrics

@classmethod
def default_args(cls) -> Dict:
    """Return the default ``monitor_args`` mapping for this monitor (empty by default)."""
    return dict()


@MONITOR.register_module("tensorboard")
class TensorboardMonitor(Monitor):
Expand All @@ -103,11 +117,24 @@ def close(self) -> None:

@MONITOR.register_module("wandb")
class WandbMonitor(Monitor):
"""Monitor with Weights & Biases.

Args:
base_url (`Optional[str]`): The base URL of the W&B server. If not provided, use the environment variable `WANDB_BASE_URL`.
api_key (`Optional[str]`): The API key for W&B. If not provided, use the environment variable `WANDB_API_KEY`.
"""

def __init__(
self, project: str, group: str, name: str, role: str, config: Config = None
) -> None:
assert wandb is not None, "wandb is not installed. Please install it to use WandbMonitor."
if not group:
group = name
monitor_args = config.monitor.monitor_args or {}
if base_url := monitor_args.get("base_url"):
os.environ["WANDB_BASE_URL"] = base_url
if api_key := monitor_args.get("api_key"):
os.environ["WANDB_API_KEY"] = api_key
self.logger = wandb.init(
project=project,
group=group,
Expand All @@ -129,3 +156,65 @@ def log(self, data: dict, step: int, commit: bool = False) -> None:

def close(self) -> None:
    """Finish the underlying W&B run, flushing any pending logged data."""
    run = self.logger
    run.finish()

@classmethod
def default_args(cls) -> Dict:
    """Default ``monitor_args`` for W&B: server URL and API key, both unset."""
    return dict(base_url=None, api_key=None)


@MONITOR.register_module("mlflow")
class MlflowMonitor(Monitor):
    """Monitor with MLflow.

    Args:
        uri (`Optional[str]`): The tracking server URI. If not provided, the default is `http://localhost:5000`.
        username (`Optional[str]`): The username to login. If not provided, the default is `None`.
        password (`Optional[str]`): The password to login. If not provided, the default is `None`.
    """

    def __init__(
        self, project: str, group: str, name: str, role: str, config: Config = None
    ) -> None:
        assert (
            mlflow is not None
        ), "mlflow is not installed. Please install it to use MlflowMonitor."
        # None-safe view of the user-supplied monitor arguments.
        monitor_args = config.monitor.monitor_args or {}
        # Credentials are handed to MLflow via environment variables.
        if username := monitor_args.get("username"):
            os.environ["MLFLOW_TRACKING_USERNAME"] = username
        if password := monitor_args.get("password"):
            os.environ["MLFLOW_TRACKING_PASSWORD"] = password
        # BUG FIX: read the URI from the None-safe local `monitor_args`.
        # The original `config.monitor.monitor_args.get(...)` raised
        # AttributeError whenever `monitor_args` was None.
        mlflow.set_tracking_uri(monitor_args.get("uri", "http://localhost:5000"))
        mlflow.set_experiment(project)
        mlflow.start_run(
            run_name=f"{name}_{role}",
            tags={
                "group": group,
                "role": role,
            },
        )
        # Record the full (flattened) config as run parameters for reproducibility.
        mlflow.log_params(config.flatten())
        self.console_logger = get_logger(__name__)

    def log_table(self, table_name: str, experiences_table: pd.DataFrame, step: int):
        # Table logging is intentionally a no-op for the MLflow backend.
        pass

    def log(self, data: dict, step: int, commit: bool = False) -> None:
        """Log metrics to MLflow at `step` and echo them to the console."""
        mlflow.log_metrics(metrics=data, step=step)
        self.console_logger.info(f"Step {step}: {data}")

    def close(self) -> None:
        """End the active MLflow run."""
        mlflow.end_run()

    @classmethod
    def default_args(cls) -> Dict:
        """Return default arguments for the monitor."""
        return {
            "uri": "http://localhost:5000",
            "username": None,
            "password": None,
        }