-import glob
-import re
-import shutil
-import tempfile
-import threading
-from os import path as osp
-from typing import Dict, List, Optional
+from typing import List

 from ray.tune.integration.mlflow import \
     MLflowLoggerCallback as _MLflowLoggerCallback
 from ray.tune.integration.mlflow import logger
-from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
 from ray.tune.trial import Trial
 from ray.tune.utils.util import is_nan_or_inf

 from .builder import CALLBACKS


-def _create_temporary_copy(path, temp_file_name):
-    temp_dir = tempfile.gettempdir()
-    temp_path = osp.join(temp_dir, temp_file_name)
-    shutil.copy2(path, temp_path)
-    return temp_path
-
-
 @CALLBACKS.register_module()
 class MLflowLoggerCallback(_MLflowLoggerCallback):
-
-    TRIAL_LIMIT = 5
+    """Custom MLflow logger to automatically log Tune results and config to
+    MLflow. The main differences from the original MLflow logger are:
+
+    1. Bind multiple runs into a parent run in the form of nested runs.
+    2. Log artifacts of the best trial to the parent run.
+
+    Refer to https://github.com/ray-project/ray/blob/ray-1.9.1/python/ray/tune/integration/mlflow.py for details.  # noqa E501
+
+    Args:
+        metric (str): Key for trial info to order on.
+        mode (str): One of [min, max]. Defaults to ``self.default_mode``.
+        scope (str): One of [all, last, avg, last-5-avg, last-10-avg].
+            If `scope=last`, only look at each trial's final step for
+            `metric`, and compare across trials based on `mode=[min,max]`.
+            If `scope=avg`, consider the simple average over all steps
+            for `metric` and compare across trials based on
+            `mode=[min,max]`. If `scope=last-5-avg` or `scope=last-10-avg`,
+            consider the simple average over the last 5 or 10 steps for
+            `metric` and compare across trials based on `mode=[min,max]`.
+            If `scope=all`, find each trial's min/max score for `metric`
+            based on `mode`, and compare trials based on `mode=[min,max]`.
+        filter_nan_and_inf (bool): If True, NaN or infinite values are
+            disregarded and such trials are never selected as the best
+            trial. Default: True.
+        **kwargs: kwargs for the original ``MLflowLoggerCallback``.
+    """

     def __init__(self,
-                 work_dir: Optional[str],
                  metric: str = None,
                  mode: str = None,
                  scope: str = 'last',
                  filter_nan_and_inf: bool = True,
                  **kwargs):
-        super().__init__(**kwargs)
-        self.work_dir = work_dir
+        super(MLflowLoggerCallback, self).__init__(**kwargs)
         self.metric = metric
         if mode and mode not in ['min', 'max']:
             raise ValueError('`mode` has to be None or one of [min, max]')
@@ -51,20 +59,17 @@ def __init__(self,
                     self.metric, scope))
         self.scope = scope if scope != 'all' else mode
         self.filter_nan_and_inf = filter_nan_and_inf
-        self.thrs = []

     def setup(self, *args, **kwargs):
-        cp_trial_runs = getattr(self, '_trial_runs', dict()).copy()
         super().setup(*args, **kwargs)
-        self._trial_runs = cp_trial_runs
         self.parent_run = self.client.create_run(
             experiment_id=self.experiment_id, tags=self.tags)

     def log_trial_start(self, trial: 'Trial'):
         # Create run if it does not already exist.
         if trial not in self._trial_runs:

-            # Set trial name in tags.
+            # Set trial name in tags
             tags = self.tags.copy()
             tags['trial_name'] = str(trial)
             tags['mlflow.parentRunId'] = self.parent_run.info.run_id
@@ -79,86 +84,9 @@ def log_trial_start(self, trial: 'Trial'):
         config = trial.config

         for key, value in config.items():
-            key = re.sub(r'[^a-zA-Z0-9_=./\s]', '', key)
             self.client.log_param(run_id=run_id, key=key, value=value)

-    def log_trial_result(self, iteration: int, trial: 'Trial', result: Dict):
-        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
-        run_id = self._trial_runs[trial]
-        for key, value in result.items():
-            key = re.sub(r'[^a-zA-Z0-9_=./\s]', '', key)
-            try:
-                value = float(value)
-            except (ValueError, TypeError):
-                logger.debug('Cannot log key {} with value {} since the '
-                             'value cannot be converted to float.'.format(
-                                 key, value))
-                continue
-            for idx in range(MLflowLoggerCallback.TRIAL_LIMIT):
-                try:
-                    self.client.log_metric(
-                        run_id=run_id, key=key, value=value, step=step)
-                except Exception as ex:
-                    print(ex)
-                    print(f'Retrying ... : {idx + 1}')
-
-    def log_trial_end(self, trial: 'Trial', failed: bool = False):
-
-        def log_artifacts(run_id,
-                          path,
-                          trial_limit=MLflowLoggerCallback.TRIAL_LIMIT):
-            for idx in range(trial_limit):
-                try:
-                    self.client.log_artifact(
-                        run_id, local_path=path, artifact_path='checkpoint')
-                except Exception as ex:
-                    print(ex)
-                    print(f'Retrying ... : {idx + 1}')
-
-        run_id = self._trial_runs[trial]
-
-        if self.save_artifact:
-            trial_id = trial.trial_id
-            work_dir = osp.join(self.work_dir, trial_id)
-            checkpoints = glob.glob(osp.join(work_dir, '*.pth'))
-            if checkpoints:
-                pth = _create_temporary_copy(
-                    max(checkpoints, key=osp.getctime), 'model_final.pth')
-                th = threading.Thread(target=log_artifacts, args=(run_id, pth))
-                self.thrs.append(th)
-                th.start()
-
-            cfg = _create_temporary_copy(
-                glob.glob(osp.join(work_dir, '*.py'))[0], 'model_config.py')
-            if cfg:
-                th = threading.Thread(target=log_artifacts, args=(run_id, cfg))
-                self.thrs.append(th)
-                th.start()
-
-        # Stop the run once trial finishes.
-        status = 'FINISHED' if not failed else 'FAILED'
-        self.client.set_terminated(run_id=run_id, status=status)
-
     def on_experiment_end(self, trials: List['Trial'], **info):
-        for th in self.thrs:
-            th.join()
-
-        def cp_artifacts(src_run_id,
-                         dst_run_id,
-                         tmp_dir,
-                         trial_limit=MLflowLoggerCallback.TRIAL_LIMIT):
-            for idx in range(trial_limit):
-                try:
-                    self.client.download_artifacts(
-                        run_id=src_run_id, path='checkpoint', dst_path=tmp_dir)
-                    self.client.log_artifacts(
-                        run_id=dst_run_id,
-                        local_dir=osp.join(tmp_dir, 'checkpoint'),
-                        artifact_path='checkpoint')
-                except Exception as ex:
-                    print(ex)
-                    print(f'Retrying ... : {idx + 1}')
-
         if not self.metric or not self.mode:
             return

@@ -185,19 +113,19 @@ def cp_artifacts(src_run_id,
         if best_trial not in self._trial_runs:
             return

+        # Copy the run of the best trial to the parent run.
         run_id = self._trial_runs[best_trial]
         run = self.client.get_run(run_id)
         parent_run_id = self.parent_run.info.run_id
+
         for key, value in run.data.params.items():
             self.client.log_param(run_id=parent_run_id, key=key, value=value)
+
         for key, value in run.data.metrics.items():
             self.client.log_metric(run_id=parent_run_id, key=key, value=value)

         if self.save_artifact:
-            tmp_dir = tempfile.gettempdir()
-            th = threading.Thread(
-                target=cp_artifacts, args=(run_id, parent_run_id, tmp_dir))
-            th.start()
-            th.join()
+            self.client.log_artifacts(
+                parent_run_id, local_dir=best_trial.logdir)

         self.client.set_terminated(run_id=parent_run_id, status='FINISHED')
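
For reference, a minimal usage sketch of the callback after this change. This is not part of the PR: the import path, experiment name, metric key, and dummy trainable are hypothetical placeholders, and the arguments forwarded via **kwargs (experiment_name, save_artifact) are those of the base MLflowLoggerCallback in Ray 1.9.

# Minimal usage sketch (assumes ray[tune]==1.9.x and mlflow are installed).
from ray import tune

# The callback is the class defined in this diff; adjust the import to
# wherever the module lives in your project (this path is hypothetical).
from my_project.callbacks.mlflow import MLflowLoggerCallback


def trainable(config):
    # Dummy objective: report an increasing score over a few steps.
    for step in range(3):
        tune.report(val_accuracy=config['lr'] * step)


tune.run(
    trainable,
    config={'lr': tune.uniform(0.01, 0.1)},
    num_samples=4,
    callbacks=[
        MLflowLoggerCallback(
            experiment_name='demo',  # forwarded to the base class
            save_artifact=True,      # log best-trial artifacts to parent run
            metric='val_accuracy',
            mode='max',
            scope='last')
    ])

With save_artifact=True, each trial is logged as a nested MLflow run under the parent run created in setup(), and on_experiment_end copies the best trial's params, metrics, and logdir artifacts to that parent run.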