
Commit 4d71e85

KKIEEK authored and KKIEEK committed
Refactor
1 parent 03439ae commit 4d71e85

1 file changed: +34 -106 lines changed

mmtune/ray/callbacks/mlflow.py (+34 -106)

@@ -1,42 +1,50 @@
-import glob
-import re
-import shutil
-import tempfile
-import threading
-from os import path as osp
-from typing import Dict, List, Optional
+from typing import List

 from ray.tune.integration.mlflow import \
     MLflowLoggerCallback as _MLflowLoggerCallback
 from ray.tune.integration.mlflow import logger
-from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
 from ray.tune.trial import Trial
 from ray.tune.utils.util import is_nan_or_inf

 from .builder import CALLBACKS


-def _create_temporary_copy(path, temp_file_name):
-    temp_dir = tempfile.gettempdir()
-    temp_path = osp.join(temp_dir, temp_file_name)
-    shutil.copy2(path, temp_path)
-    return temp_path
-
-
 @CALLBACKS.register_module()
 class MLflowLoggerCallback(_MLflowLoggerCallback):
-
-    TRIAL_LIMIT = 5
+    """Custom MLflow Logger to automatically log Tune results and config to
+    MLflow. The main differences from the original MLflow Logger are:
+
+    1. Bind multiple runs into a parent run in the form of nested run.
+    2. Log artifacts of the best trial to the parent run.
+
+    Refer to https://github.com/ray-project/ray/blob/ray-1.9.1/python/ray/tune/integration/mlflow.py for details. # noqa E501
+
+    Args:
+        metric (str): Key for trial info to order on.
+        mode (str): One of [min, max]. Defaults to ``self.default_mode``.
+        scope (str): One of [all, last, avg, last-5-avg, last-10-avg].
+            If `scope=last`, only look at each trial's final step for
+            `metric`, and compare across trials based on `mode=[min,max]`.
+            If `scope=avg`, consider the simple average over all steps
+            for `metric` and compare across trials based on
+            `mode=[min,max]`. If `scope=last-5-avg` or `scope=last-10-avg`,
+            consider the simple average over the last 5 or 10 steps for
+            `metric` and compare across trials based on `mode=[min,max]`.
+            If `scope=all`, find each trial's min/max score for `metric`
+            based on `mode`, and compare trials based on `mode=[min,max]`.
+        filter_nan_and_inf (bool): If True, NaN or infinite values
+            are disregarded and these trials are never selected as
+            the best trial. Default: True.
+        **kwargs: kwargs for original ``MLflowLoggerCallback``
+    """

     def __init__(self,
-                 work_dir: Optional[str],
                  metric: str = None,
                  mode: str = None,
                  scope: str = 'last',
                  filter_nan_and_inf: bool = True,
                  **kwargs):
-        super().__init__(**kwargs)
-        self.work_dir = work_dir
+        super(MLflowLoggerCallback, self).__init__(**kwargs)
         self.metric = metric
         if mode and mode not in ['min', 'max']:
             raise ValueError('`mode` has to be None or one of [min, max]')
@@ -51,20 +59,17 @@ def __init__(self,
                     self.metric, scope))
         self.scope = scope if scope != 'all' else mode
         self.filter_nan_and_inf = filter_nan_and_inf
-        self.thrs = []

     def setup(self, *args, **kwargs):
-        cp_trial_runs = getattr(self, '_trial_runs', dict()).copy()
         super().setup(*args, **kwargs)
-        self._trial_runs = cp_trial_runs
         self.parent_run = self.client.create_run(
             experiment_id=self.experiment_id, tags=self.tags)

     def log_trial_start(self, trial: 'Trial'):
         # Create run if not already exists.
         if trial not in self._trial_runs:

-            # Set trial name in tags.
+            # Set trial name in tags
             tags = self.tags.copy()
             tags['trial_name'] = str(trial)
             tags['mlflow.parentRunId'] = self.parent_run.info.run_id
@@ -79,86 +84,9 @@ def log_trial_start(self, trial: 'Trial'):
         config = trial.config

         for key, value in config.items():
-            key = re.sub(r'[^a-zA-Z0-9_=./\s]', '', key)
             self.client.log_param(run_id=run_id, key=key, value=value)

-    def log_trial_result(self, iteration: int, trial: 'Trial', result: Dict):
-        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
-        run_id = self._trial_runs[trial]
-        for key, value in result.items():
-            key = re.sub(r'[^a-zA-Z0-9_=./\s]', '', key)
-            try:
-                value = float(value)
-            except (ValueError, TypeError):
-                logger.debug('Cannot log key {} with value {} since the '
-                             'value cannot be converted to float.'.format(
-                                 key, value))
-                continue
-            for idx in range(MLflowLoggerCallback.TRIAL_LIMIT):
-                try:
-                    self.client.log_metric(
-                        run_id=run_id, key=key, value=value, step=step)
-                except Exception as ex:
-                    print(ex)
-                    print(f'Retrying ... : {idx+1}')
-
-    def log_trial_end(self, trial: 'Trial', failed: bool = False):
-
-        def log_artifacts(run_id,
-                          path,
-                          trial_limit=MLflowLoggerCallback.TRIAL_LIMIT):
-            for idx in range(trial_limit):
-                try:
-                    self.client.log_artifact(
-                        run_id, local_path=path, artifact_path='checkpoint')
-                except Exception as ex:
-                    print(ex)
-                    print(f'Retrying ... : {idx+1}')
-
-        run_id = self._trial_runs[trial]
-
-        if self.save_artifact:
-            trial_id = trial.trial_id
-            work_dir = osp.join(self.work_dir, trial_id)
-            checkpoints = glob.glob(osp.join(work_dir, '*.pth'))
-            if checkpoints:
-                pth = _create_temporary_copy(
-                    max(checkpoints, key=osp.getctime), 'model_final.pth')
-                th = threading.Thread(target=log_artifacts, args=(run_id, pth))
-                self.thrs.append(th)
-                th.start()
-
-            cfg = _create_temporary_copy(
-                glob.glob(osp.join(work_dir, '*.py'))[0], 'model_config.py')
-            if cfg:
-                th = threading.Thread(target=log_artifacts, args=(run_id, cfg))
-                self.thrs.append(th)
-                th.start()
-
-        # Stop the run once trial finishes.
-        status = 'FINISHED' if not failed else 'FAILED'
-        self.client.set_terminated(run_id=run_id, status=status)
-
     def on_experiment_end(self, trials: List['Trial'], **info):
-        for th in self.thrs:
-            th.join()
-
-        def cp_artifacts(src_run_id,
-                         dst_run_id,
-                         tmp_dir,
-                         trial_limit=MLflowLoggerCallback.TRIAL_LIMIT):
-            for idx in range(trial_limit):
-                try:
-                    self.client.download_artifacts(
-                        run_id=src_run_id, path='checkpoint', dst_path=tmp_dir)
-                    self.client.log_artifacts(
-                        run_id=dst_run_id,
-                        local_dir=osp.join(tmp_dir, 'checkpoint'),
-                        artifact_path='checkpoint')
-                except Exception as ex:
-                    print(ex)
-                    print(f'Retrying ... : {idx+1}')
-
         if not self.metric or not self.mode:
             return

@@ -185,19 +113,19 @@ def cp_artifacts(src_run_id,
         if best_trial not in self._trial_runs:
             return

+        # Copy the run of best trial to parent run.
         run_id = self._trial_runs[best_trial]
         run = self.client.get_run(run_id)
         parent_run_id = self.parent_run.info.run_id
+
         for key, value in run.data.params.items():
             self.client.log_param(run_id=parent_run_id, key=key, value=value)
+
         for key, value in run.data.metrics.items():
             self.client.log_metric(run_id=parent_run_id, key=key, value=value)

         if self.save_artifact:
-            tmp_dir = tempfile.gettempdir()
-            th = threading.Thread(
-                target=cp_artifacts, args=(run_id, parent_run_id, tmp_dir))
-            th.start()
-            th.join()
+            self.client.log_artifacts(
+                parent_run_id, local_dir=best_trial.logdir)

         self.client.set_terminated(run_id=parent_run_id, status='FINISHED')
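
Below is a minimal usage sketch, not part of the commit, showing how the refactored callback might be attached to a Ray Tune run. It assumes Ray 1.9.x, that ``MLflowLoggerCallback`` is importable from the module path shown above, and that extra keyword arguments such as ``experiment_name`` and ``save_artifact`` are forwarded to the upstream Ray callback via ``**kwargs``, as the constructor in the diff suggests; the trainable and config values are made up for illustration.

# Hypothetical example; values and the trainable are illustrative only.
from ray import tune

from mmtune.ray.callbacks.mlflow import MLflowLoggerCallback


def trainable(config):
    # Report a dummy metric so the callback has something to log per step.
    for step in range(10):
        tune.report(loss=config['lr'] * step)


callback = MLflowLoggerCallback(
    metric='loss',  # key used to select the best trial
    mode='min',  # lower loss is better
    scope='last',  # compare trials on their final reported value
    experiment_name='demo',  # forwarded to the upstream MLflowLoggerCallback
    save_artifact=True)  # log the best trial's logdir to the parent run

tune.run(
    trainable,
    config={'lr': tune.grid_search([0.01, 0.1])},
    callbacks=[callback])

With this setup, each trial appears as a nested MLflow run under a single parent run, and when the experiment ends the parent run receives the params, metrics, and (with ``save_artifact=True``) the logdir artifacts of the best trial.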
