Skip to content

Commit 38a4e60

Browse files
KKIEEK
authored and
KKIEEK
committed
Support ray logger callbacks
1 parent efea853 commit 38a4e60

File tree

4 files changed

+228
-1
lines changed

4 files changed

+228
-1
lines changed

mmtune/apis/tune.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from mmcv.utils import Config
66

77
from mmtune.mm.tasks import BaseTask
8+
from mmtune.ray.callbacks.builder import build_callback
89
from mmtune.ray.schedulers import build_scheduler
910
from mmtune.ray.searchers import build_searcher
1011
from mmtune.ray.spaces import build_space
@@ -47,6 +48,10 @@ def tune(task_processor: BaseTask, tune_config: Config,
4748
if scheduler is not None:
4849
scheduler = build_scheduler(scheduler)
4950

51+
callbacks = tune_config.get('callbacks', None)
52+
if callbacks is not None:
53+
callbacks = [build_callback(callback) for callback in callbacks]
54+
5055
return ray.tune.run(
5156
trainable,
5257
name=exp_name,
@@ -59,4 +64,5 @@ def tune(task_processor: BaseTask, tune_config: Config,
5964
local_dir=tune_artifact_dir,
6065
search_alg=searcher,
6166
scheduler=scheduler,
62-
raise_on_failed_trial=tune_config.get('raise_on_failed_trial', False))
67+
raise_on_failed_trial=tune_config.get('raise_on_failed_trial', False),
68+
callbacks=callbacks)

mmtune/ray/callbacks/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .builder import CALLBACKS, build_callback
2+
from .mlflow import MLflowLoggerCallback
3+
4+
__all__ = ['CALLBACKS', 'build_callback', 'MLflowLoggerCallback']

mmtune/ray/callbacks/builder.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from mmcv.utils import Config, Registry
2+
from ray.tune.logger import (CSVLoggerCallback, JsonLoggerCallback,
3+
LegacyLoggerCallback, LoggerCallback,
4+
TBXLoggerCallback)
5+
6+
CALLBACKS = Registry('callbacks')

# Expose ray's built-in logger callbacks through the registry so they can
# be selected from config files by class name. Registration order matches
# the original explicit calls.
for _logger_callback in (LegacyLoggerCallback, JsonLoggerCallback,
                         CSVLoggerCallback, TBXLoggerCallback):
    CALLBACKS.register_module(module=_logger_callback)
11+
12+
13+
def build_callback(cfg: Config) -> LoggerCallback:
    """Build a ray ``LoggerCallback`` from a registry config.

    Args:
        cfg (Config): Config dict whose ``type`` key selects a class
            registered in ``CALLBACKS``; remaining keys are passed to
            its constructor.

    Returns:
        LoggerCallback: The constructed callback instance.
    """
    return CALLBACKS.build(cfg)

mmtune/ray/callbacks/mlflow.py

+203
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
import glob
2+
import re
3+
import shutil
4+
import tempfile
5+
import threading
6+
from os import path as osp
7+
from typing import Dict, List, Optional
8+
9+
from ray.tune.integration.mlflow import \
10+
MLflowLoggerCallback as _MLflowLoggerCallback
11+
from ray.tune.integration.mlflow import logger
12+
from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
13+
from ray.tune.trial import Trial
14+
from ray.tune.utils.util import is_nan_or_inf
15+
16+
from .builder import CALLBACKS
17+
18+
19+
def _create_temporary_copy(path, temp_file_name):
    """Copy ``path`` into the system temp directory as ``temp_file_name``.

    Uses ``shutil.copy2`` so file metadata is preserved.

    Returns:
        str: Absolute path of the copy inside the temp directory.
    """
    destination = osp.join(tempfile.gettempdir(), temp_file_name)
    shutil.copy2(path, destination)
    return destination
24+
25+
26+
@CALLBACKS.register_module()
class MLflowLoggerCallback(_MLflowLoggerCallback):
    """MLflow logger callback that nests each trial under one parent run.

    Extends ray's ``MLflowLoggerCallback`` to (a) create a parent MLflow
    run and attach each trial as a child run, (b) upload the newest
    checkpoint (``*.pth``) and the config file (``*.py``) of a finished
    trial in background threads, and (c) copy the best trial's params,
    metrics and artifacts onto the parent run when the experiment ends.

    Args:
        work_dir: Directory trials write into; one sub-directory per
            trial id containing ``*.pth`` checkpoints and a ``*.py`` config.
        metric: Metric used to pick the best trial. If ``None``, no
            best-trial summary is written to the parent run.
        mode: ``'min'`` or ``'max'`` (or ``None``); direction of ``metric``.
        scope: Which observed value of the metric to compare. One of
            ``'all'``, ``'last'``, ``'avg'``, ``'last-5-avg'``,
            ``'last-10-avg'``.
        filter_nan_and_inf: Skip trials whose score is NaN or +/-inf.
        **kwargs: Forwarded to ray's ``MLflowLoggerCallback``.
    """

    # Max attempts for each MLflow client call that may fail transiently.
    TRIAL_LIMIT = 5

    def __init__(self,
                 work_dir: Optional[str],
                 metric: Optional[str] = None,
                 mode: Optional[str] = None,
                 scope: str = 'last',
                 filter_nan_and_inf: bool = True,
                 **kwargs):
        super().__init__(**kwargs)
        self.work_dir = work_dir
        self.metric = metric
        if mode and mode not in ['min', 'max']:
            raise ValueError('`mode` has to be None or one of [min, max]')
        self.mode = mode
        if scope not in ['all', 'last', 'avg', 'last-5-avg', 'last-10-avg']:
            raise ValueError(
                'ExperimentAnalysis: attempting to get best trial for '
                "metric {} for scope {} not in [\"all\", \"last\", \"avg\", "
                "\"last-5-avg\", \"last-10-avg\"]. "
                "If you didn't pass a `metric` parameter to `tune.run()`, "
                'you have to pass one when fetching the best trial.'.format(
                    self.metric, scope))
        # 'all' defers to mode ('min'/'max'), mirroring ExperimentAnalysis.
        self.scope = scope if scope != 'all' else mode
        self.filter_nan_and_inf = filter_nan_and_inf
        # Background artifact-upload threads; joined in on_experiment_end.
        self.thrs = []

    def setup(self, *args, **kwargs):
        """Initialize the MLflow client and create the parent run.

        Preserves any existing trial-run mapping so a repeated ``setup``
        call does not orphan previously created child runs.
        """
        cp_trial_runs = getattr(self, '_trial_runs', dict()).copy()
        super().setup(*args, **kwargs)
        self._trial_runs = cp_trial_runs
        self.parent_run = self.client.create_run(
            experiment_id=self.experiment_id, tags=self.tags)

    def log_trial_start(self, trial: 'Trial'):
        """Create a child run for ``trial`` and log its config as params."""
        # Create run if not already exists.
        if trial not in self._trial_runs:

            # Set trial name in tags and nest under the parent run.
            tags = self.tags.copy()
            tags['trial_name'] = str(trial)
            tags['mlflow.parentRunId'] = self.parent_run.info.run_id

            run = self.client.create_run(
                experiment_id=self.experiment_id, tags=tags)
            self._trial_runs[trial] = run.info.run_id

        run_id = self._trial_runs[trial]

        # Log the config parameters.
        config = trial.config

        for key, value in config.items():
            # MLflow accepts only a restricted character set in param keys.
            key = re.sub(r'[^a-zA-Z0-9_=./\s]', '', key)
            self.client.log_param(run_id=run_id, key=key, value=value)

    def log_trial_result(self, iteration: int, trial: 'Trial', result: Dict):
        """Log every float-convertible entry of ``result`` as a metric."""
        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
        run_id = self._trial_runs[trial]
        for key, value in result.items():
            key = re.sub(r'[^a-zA-Z0-9_=./\s]', '', key)
            try:
                value = float(value)
            except (ValueError, TypeError):
                logger.debug('Cannot log key {} with value {} since the '
                             'value cannot be converted to float.'.format(
                                 key, value))
                continue
            for idx in range(MLflowLoggerCallback.TRIAL_LIMIT):
                try:
                    self.client.log_metric(
                        run_id=run_id, key=key, value=value, step=step)
                    # Fix: stop once the call succeeds; the original kept
                    # looping and re-logged every metric TRIAL_LIMIT times.
                    break
                except Exception as ex:
                    logger.warning('%s\nRetrying ... : %d', ex, idx + 1)

    def log_trial_end(self, trial: 'Trial', failed: bool = False):
        """Upload trial artifacts in the background and terminate its run."""

        def log_artifacts(run_id,
                          path,
                          trial_limit=MLflowLoggerCallback.TRIAL_LIMIT):
            # Retry transient upload failures up to ``trial_limit`` times.
            for idx in range(trial_limit):
                try:
                    self.client.log_artifact(
                        run_id, local_path=path, artifact_path='checkpoint')
                    break  # fix: do not retry after a successful upload
                except Exception as ex:
                    logger.warning('%s\nRetrying ... : %d', ex, idx + 1)

        run_id = self._trial_runs[trial]

        if self.save_artifact:
            trial_id = trial.trial_id
            work_dir = osp.join(self.work_dir, trial_id)
            checkpoints = glob.glob(osp.join(work_dir, '*.pth'))
            if checkpoints:
                # Copy the newest checkpoint so the upload thread is not
                # racing against files in the trial's work dir.
                pth = _create_temporary_copy(
                    max(checkpoints, key=osp.getctime), 'model_final.pth')
                th = threading.Thread(target=log_artifacts, args=(run_id, pth))
                self.thrs.append(th)
                th.start()

            configs = glob.glob(osp.join(work_dir, '*.py'))
            # Fix: the original indexed [0] unconditionally and raised
            # IndexError when no config file was present.
            if configs:
                cfg = _create_temporary_copy(configs[0], 'model_config.py')
                th = threading.Thread(target=log_artifacts, args=(run_id, cfg))
                self.thrs.append(th)
                th.start()

        # Stop the run once trial finishes.
        status = 'FINISHED' if not failed else 'FAILED'
        self.client.set_terminated(run_id=run_id, status=status)

    def on_experiment_end(self, trials: List['Trial'], **info):
        """Promote the best trial's results to the parent run and close it."""
        # Wait for all pending artifact uploads before summarizing.
        for th in self.thrs:
            th.join()

        def cp_artifacts(src_run_id,
                         dst_run_id,
                         tmp_dir,
                         trial_limit=MLflowLoggerCallback.TRIAL_LIMIT):
            # Copy the best trial's checkpoint artifacts to the parent run,
            # retrying transient failures up to ``trial_limit`` times.
            for idx in range(trial_limit):
                try:
                    self.client.download_artifacts(
                        run_id=src_run_id, path='checkpoint', dst_path=tmp_dir)
                    self.client.log_artifacts(
                        run_id=dst_run_id,
                        local_dir=osp.join(tmp_dir, 'checkpoint'),
                        artifact_path='checkpoint')
                    break  # fix: do not retry after success
                except Exception as ex:
                    logger.warning('%s\nRetrying ... : %d', ex, idx + 1)

        if not self.metric or not self.mode:
            return

        best_trial, best_score = None, None
        for trial in trials:
            if self.metric not in trial.metric_analysis:
                continue

            score = trial.metric_analysis[self.metric][self.scope]
            if self.filter_nan_and_inf and is_nan_or_inf(score):
                continue

            # Fix: `best_score or score` treated a legitimate score of 0.0
            # as "unset"; test for None explicitly instead.
            if best_score is None:
                best_score = score
            if (self.mode == 'max' and score >= best_score) or (
                    self.mode == 'min' and score <= best_score):
                best_trial, best_score = trial, score

        if best_trial is None:
            logger.warning(
                'Could not find best trial. Did you pass the correct `metric` '
                'parameter?')
            return

        if best_trial not in self._trial_runs:
            return

        run_id = self._trial_runs[best_trial]
        run = self.client.get_run(run_id)
        parent_run_id = self.parent_run.info.run_id
        for key, value in run.data.params.items():
            self.client.log_param(run_id=parent_run_id, key=key, value=value)
        for key, value in run.data.metrics.items():
            self.client.log_metric(run_id=parent_run_id, key=key, value=value)

        if self.save_artifact:
            tmp_dir = tempfile.gettempdir()
            th = threading.Thread(
                target=cp_artifacts, args=(run_id, parent_run_id, tmp_dir))
            th.start()
            th.join()

        self.client.set_terminated(run_id=parent_run_id, status='FINISHED')

0 commit comments

Comments
 (0)