
Commit 69f23c8

KKIEEK
authored and
KKIEEK
committed
Refactor
1 parent 9626dad commit 69f23c8

File tree

1 file changed: +36 -49 lines changed

mmtune/ray/callbacks/mlflow.py (+36 -49)
@@ -1,12 +1,8 @@
-import glob
-import re
-from os import path as osp
-from typing import Dict, List, Optional
+from typing import List
 
 from ray.tune.integration.mlflow import \
     MLflowLoggerCallback as _MLflowLoggerCallback
 from ray.tune.integration.mlflow import logger
-from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
 from ray.tune.trial import Trial
 from ray.tune.utils.util import is_nan_or_inf
 
@@ -15,16 +11,41 @@
 
 @CALLBACKS.register_module()
 class MLflowLoggerCallback(_MLflowLoggerCallback):
+    """Custom MLflow Logger to automatically log Tune results and config to
+    MLflow. The main differences from the original MLflow Logger are:
+
+    1. Bind multiple runs into a parent run in the form of nested runs.
+    2. Log artifacts of the best trial to the parent run.
+
+    Refer to https://github.com/ray-project/ray/blob/ray-1.9.1/python/ray/tune/integration/mlflow.py for details.  # noqa E501
+
+    Args:
+        metric: Key for trial info to order on. Defaults to
+            ``self.default_metric``.
+        mode: One of [min, max]. Defaults to ``self.default_mode``.
+        scope: One of [all, last, avg, last-5-avg, last-10-avg].
+            If `scope=last`, only look at each trial's final step for
+            `metric`, and compare across trials based on `mode=[min,max]`.
+            If `scope=avg`, consider the simple average over all steps
+            for `metric` and compare across trials based on
+            `mode=[min,max]`. If `scope=last-5-avg` or `scope=last-10-avg`,
+            consider the simple average over the last 5 or 10 steps for
+            `metric` and compare across trials based on `mode=[min,max]`.
+            If `scope=all`, find each trial's min/max score for `metric`
+            based on `mode`, and compare trials based on `mode=[min,max]`.
+        filter_nan_and_inf (bool): If True, NaN or infinite values
+            are disregarded and these trials are never selected as
+            the best trial. Default: True.
+        **kwargs: kwargs for the original ``MLflowLoggerCallback``.
+    """
 
     def __init__(self,
-                 work_dir: Optional[str],
                  metric: str = None,
                  mode: str = None,
                  scope: str = 'last',
                  filter_nan_and_inf: bool = True,
                  **kwargs):
-        super().__init__(**kwargs)
-        self.work_dir = work_dir
+        super(MLflowLoggerCallback, self).__init__(**kwargs)
         self.metric = metric
         if mode and mode not in ['min', 'max']:
             raise ValueError('`mode` has to be None or one of [min, max]')
@@ -41,17 +62,15 @@ def __init__(self,
         self.filter_nan_and_inf = filter_nan_and_inf
 
     def setup(self, *args, **kwargs):
-        cp_trial_runs = getattr(self, '_trial_runs', dict()).copy()
         super().setup(*args, **kwargs)
-        self._trial_runs = cp_trial_runs
         self.parent_run = self.client.create_run(
             experiment_id=self.experiment_id, tags=self.tags)
 
     def log_trial_start(self, trial: 'Trial'):
         # Create run if not already exists.
         if trial not in self._trial_runs:
 
-            # Set trial name in tags.
+            # Set trial name in tags
            tags = self.tags.copy()
             tags['trial_name'] = str(trial)
             tags['mlflow.parentRunId'] = self.parent_run.info.run_id
@@ -66,37 +85,8 @@ def log_trial_start(self, trial: 'Trial'):
             config = trial.config
 
         for key, value in config.items():
-            key = re.sub(r'[^a-zA-Z0-9_=./\s]', '', key)
             self.client.log_param(run_id=run_id, key=key, value=value)
 
-    def log_trial_result(self, iteration: int, trial: 'Trial', result: Dict):
-        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
-        run_id = self._trial_runs[trial]
-        for key, value in result.items():
-            key = re.sub(r'[^a-zA-Z0-9_=./\s]', '', key)
-            try:
-                value = float(value)
-            except (ValueError, TypeError):
-                logger.debug('Cannot log key {} with value {} since the '
-                             'value cannot be converted to float.'.format(
-                                 key, value))
-                continue
-
-            self.client.log_metric(
-                run_id=run_id, key=key, value=value, step=step)
-
-    def log_trial_end(self, trial: 'Trial', failed: bool = False):
-        run_id = self._trial_runs[trial]
-        trial_id = trial.trial_id
-        work_dir = osp.join(self.work_dir, trial_id)
-        config = glob.glob(osp.join(work_dir, '*.py'))
-        if config:
-            self.client.log_artifact(run_id, local_path=config.pop())
-
-        # Stop the run once trial finishes.
-        status = 'FINISHED' if not failed else 'FAILED'
-        self.client.set_terminated(run_id=run_id, status=status)
-
     def on_experiment_end(self, trials: List['Trial'], **info):
         if not self.metric or not self.mode:
             return
@@ -124,22 +114,19 @@ def on_experiment_end(self, trials: List['Trial'], **info):
         if best_trial not in self._trial_runs:
             return
 
+        # Copy the run of the best trial to the parent run.
         run_id = self._trial_runs[best_trial]
         run = self.client.get_run(run_id)
         parent_run_id = self.parent_run.info.run_id
+
         for key, value in run.data.params.items():
             self.client.log_param(run_id=parent_run_id, key=key, value=value)
+
         for key, value in run.data.metrics.items():
             self.client.log_metric(run_id=parent_run_id, key=key, value=value)
 
-        trial_id = best_trial.trial_id
-        work_dir = osp.join(self.work_dir, trial_id)
-        config = glob.glob(osp.join(work_dir, '*.py'))
-        if config:
-            self.client.log_artifact(parent_run_id, local_path=config.pop())
-
-        checkpoints = glob.glob(osp.join(work_dir, '*.pth'))
-        for checkpoint in checkpoints:
-            self.client.log_artifact(parent_run_id, local_path=checkpoint)
+        if self.save_artifact:
+            self.client.log_artifacts(
+                parent_run_id, local_dir=best_trial.logdir)
 
         self.client.set_terminated(run_id=parent_run_id, status='FINISHED')
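
For orientation, a minimal sketch of how the refactored callback might be driven from Ray Tune (ray 1.9.x). The import path, trainable, metric name, and search space below are illustrative assumptions, not part of this commit; `tracking_uri`, `experiment_name`, and `save_artifact` are forwarded via **kwargs to the upstream MLflowLoggerCallback.

    from ray import tune

    # Assumed import path, mirroring mmtune/ray/callbacks/mlflow.py;
    # not verified against this commit.
    from mmtune.ray.callbacks.mlflow import MLflowLoggerCallback


    def trainable(config):
        # Toy objective so each trial reports a metric that MLflow can log.
        tune.report(val_accuracy=config['lr'] * 10)


    callback = MLflowLoggerCallback(
        tracking_uri='http://localhost:5000',  # forwarded to the upstream callback
        experiment_name='demo',                # forwarded to the upstream callback
        save_artifact=True,     # on_experiment_end copies best_trial.logdir
        metric='val_accuracy',  # key used to select the best trial
        mode='max',             # best trial = highest val_accuracy
        scope='last',           # compare trials on their final reported step
    )

    tune.run(
        trainable,
        config={'lr': tune.grid_search([0.01, 0.1])},
        callbacks=[callback],
    )

With this wiring, each trial becomes a nested MLflow run under the parent run created in setup(), and at experiment end the best trial's params, metrics, and (when save_artifact is set) its log directory are copied to the parent run.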
