Skip to content

Commit bb04144

Browse files
KKIEEK and nijkah
authored
Support ray logger callbacks (#31)
* Support ray logger callbacks
* Refactor
* Append docstring
* Add example config

Co-authored-by: KKIEEK <kkieek@KKIEEKui-MacBookPro.local>
Co-authored-by: nijkah <nijkah@gmail.com>
1 parent 8d39aaf commit bb04144

File tree

5 files changed

+180
-1
lines changed

5 files changed

+180
-1
lines changed
+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Example Ray Tune logger-callback configuration.
# Each dict is passed to ``build_callback``; ``type`` names a callback
# registered in the CALLBACKS registry.
callbacks = [
    dict(
        type='MLflowLoggerCallback',
        experiment_name='mmtune',
        # Upload trial artifacts of the best run to the parent MLflow run.
        save_artifact=True,
        metric='train/loss',
        # BUG FIX: a loss is minimized — ``mode='max'`` would select the
        # WORST trial as the best one when copying results to the parent run.
        mode='min',
    ),
]

mmtune/apis/tune.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from mmcv.utils import Config
66

77
from mmtune.mm.tasks import BaseTask
8+
from mmtune.ray.callbacks import build_callback
89
from mmtune.ray.schedulers import build_scheduler
910
from mmtune.ray.searchers import build_searcher
1011
from mmtune.ray.spaces import build_space
@@ -58,6 +59,10 @@ def tune(task_processor: BaseTask, tune_config: Config,
5859
if scheduler is not None:
5960
scheduler = build_scheduler(scheduler)
6061

62+
callbacks = tune_config.get('callbacks', None)
63+
if callbacks is not None:
64+
callbacks = [build_callback(callback) for callback in callbacks]
65+
6166
return ray.tune.run(
6267
trainable,
6368
name=exp_name,
@@ -70,4 +75,5 @@ def tune(task_processor: BaseTask, tune_config: Config,
7075
local_dir=tune_artifact_dir,
7176
search_alg=searcher,
7277
scheduler=scheduler,
73-
raise_on_failed_trial=tune_config.get('raise_on_failed_trial', False))
78+
raise_on_failed_trial=tune_config.get('raise_on_failed_trial', False),
79+
callbacks=callbacks)

mmtune/ray/callbacks/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Public API of the ``mmtune.ray.callbacks`` package: the callback registry,
# its builder, and the custom MLflow logger callback.
from .builder import CALLBACKS, build_callback
from .mlflow import MLflowLoggerCallback

__all__ = ['CALLBACKS', 'build_callback', 'MLflowLoggerCallback']

mmtune/ray/callbacks/builder.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from mmcv.utils import Config, Registry
from ray.tune.logger import (CSVLoggerCallback, JsonLoggerCallback,
                             LegacyLoggerCallback, LoggerCallback,
                             TBXLoggerCallback)

# Registry of Ray Tune logger callbacks. Ray's built-in logger callbacks are
# pre-registered so config files can reference them by class name.
CALLBACKS = Registry('callbacks')
CALLBACKS.register_module(module=LegacyLoggerCallback)
CALLBACKS.register_module(module=JsonLoggerCallback)
CALLBACKS.register_module(module=CSVLoggerCallback)
CALLBACKS.register_module(module=TBXLoggerCallback)


def build_callback(cfg: Config) -> LoggerCallback:
    """Instantiate a logger callback from a config.

    Args:
        cfg (Config): Callback config; its ``type`` key must name a class
            registered in ``CALLBACKS``.

    Returns:
        LoggerCallback: The constructed callback instance.
    """
    return CALLBACKS.build(cfg)

mmtune/ray/callbacks/mlflow.py

+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
from typing import List, Optional

from ray.tune.integration.mlflow import \
    MLflowLoggerCallback as _MLflowLoggerCallback
from ray.tune.integration.mlflow import logger
from ray.tune.trial import Trial
from ray.tune.utils.util import is_nan_or_inf

from .builder import CALLBACKS
10+
11+
12+
@CALLBACKS.register_module()
class MLflowLoggerCallback(_MLflowLoggerCallback):
    """Custom MLflow Logger to automatically log Tune results and config to
    MLflow. The main differences from the original MLflow Logger are:

    1. Bind multiple runs into a parent run in the form of nested run.
    2. Log artifacts of the best trial to the parent run.

    Refer to https://github.com/ray-project/ray/blob/ray-1.9.1/python/ray/tune/integration/mlflow.py for details. # noqa E501

    Args:
        metric (str, optional): Key for trial info to order on. If ``None``,
            no best trial is selected at experiment end.
        mode (str, optional): One of [min, max]. Whether ``metric`` should be
            minimized or maximized when picking the best trial.
            Defaults to None.
        scope (str): One of [all, last, avg, last-5-avg, last-10-avg].
            If `scope=last`, only look at each trial's final step for
            `metric`, and compare across trials based on `mode=[min,max]`.
            If `scope=avg`, consider the simple average over all steps
            for `metric` and compare across trials based on
            `mode=[min,max]`. If `scope=last-5-avg` or `scope=last-10-avg`,
            consider the simple average over the last 5 or 10 steps for
            `metric` and compare across trials based on `mode=[min,max]`.
            If `scope=all`, find each trial's min/max score for `metric`
            based on `mode`, and compare trials based on `mode=[min,max]`.
        filter_nan_and_inf (bool): If True, NaN or infinite values
            are disregarded and these trials are never selected as
            the best trial. Default: True.
        **kwargs: kwargs for original ``MLflowLoggerCallback``
    """

    def __init__(self,
                 metric: Optional[str] = None,
                 mode: Optional[str] = None,
                 scope: str = 'last',
                 filter_nan_and_inf: bool = True,
                 **kwargs):
        super(MLflowLoggerCallback, self).__init__(**kwargs)
        self.metric = metric
        if mode and mode not in ['min', 'max']:
            raise ValueError('`mode` has to be None or one of [min, max]')
        self.mode = mode
        if scope not in ['all', 'last', 'avg', 'last-5-avg', 'last-10-avg']:
            raise ValueError(
                'ExperimentAnalysis: attempting to get best trial for '
                "metric {} for scope {} not in [\"all\", \"last\", \"avg\", "
                "\"last-5-avg\", \"last-10-avg\"]. "
                "If you didn't pass a `metric` parameter to `tune.run()`, "
                'you have to pass one when fetching the best trial.'.format(
                    self.metric, scope))
        # `trial.metric_analysis` keys per-trial aggregates by 'min'/'max'/
        # 'avg'/'last'/... — for scope='all' the per-trial extremum chosen
        # by `mode` is the right key to read.
        self.scope = scope if scope != 'all' else mode
        self.filter_nan_and_inf = filter_nan_and_inf

    def setup(self, *args, **kwargs):
        """In addition to create `mlflow` experiment, create a parent run to
        bind multiple trial runs."""
        super().setup(*args, **kwargs)
        self.parent_run = self.client.create_run(
            experiment_id=self.experiment_id, tags=self.tags)

    def log_trial_start(self, trial: 'Trial'):
        """Overrides `log_trial_start` of original `MLflowLoggerCallback` to
        set the parent run ID.

        Args:
            trial (Trial): `ray.tune.trial.Trial`
        """
        # Create run if not already exists.
        if trial not in self._trial_runs:

            # Set trial name in tags
            tags = self.tags.copy()
            tags['trial_name'] = str(trial)
            # Nest this trial's run under the parent run created in `setup`.
            tags['mlflow.parentRunId'] = self.parent_run.info.run_id

            run = self.client.create_run(
                experiment_id=self.experiment_id, tags=tags)
            self._trial_runs[trial] = run.info.run_id

        run_id = self._trial_runs[trial]

        # Log the config parameters.
        config = trial.config

        for key, value in config.items():
            self.client.log_param(run_id=run_id, key=key, value=value)

    def on_experiment_end(self, trials: List['Trial'], **info):
        """Overrides `on_experiment_end` of `LoggerCallback` to copy the best
        trial to the parent run. Called after experiment is over and all
        trials have concluded.

        Args:
            trials (List[Trial]): List of trials.
            **info: Kwargs dict for forward compatibility.
        """
        # Best-trial selection requires both a metric and a direction.
        if not self.metric or not self.mode:
            return

        best_trial, best_score = None, None
        for trial in trials:
            if self.metric not in trial.metric_analysis:
                continue

            score = trial.metric_analysis[self.metric][self.scope]
            if self.filter_nan_and_inf and is_nan_or_inf(score):
                continue

            # BUG FIX: the previous `best_score = best_score or score`
            # treated a valid score of 0 / 0.0 as "not yet set", so a zero
            # best score was silently overwritten. Compare against None.
            if best_score is None:
                best_score = score
            if (self.mode == 'max' and score >= best_score) or (
                    self.mode == 'min' and score <= best_score):
                best_trial, best_score = trial, score

        if best_trial is None:
            logger.warning(
                'Could not find best trial. Did you pass the correct `metric` '
                'parameter?')
            return

        if best_trial not in self._trial_runs:
            return

        # Copy the run of best trial to parent run.
        run_id = self._trial_runs[best_trial]
        run = self.client.get_run(run_id)
        parent_run_id = self.parent_run.info.run_id

        for key, value in run.data.params.items():
            self.client.log_param(run_id=parent_run_id, key=key, value=value)

        for key, value in run.data.metrics.items():
            self.client.log_metric(run_id=parent_run_id, key=key, value=value)

        if self.save_artifact:
            self.client.log_artifacts(
                parent_run_id, local_dir=best_trial.logdir)

        self.client.set_terminated(run_id=parent_run_id, status='FINISHED')

0 commit comments

Comments
 (0)