@@ -1,12 +1,8 @@
-import glob
-import re
-from os import path as osp
-from typing import Dict, List, Optional
+from typing import List
 
 from ray.tune.integration.mlflow import \
     MLflowLoggerCallback as _MLflowLoggerCallback
 from ray.tune.integration.mlflow import logger
-from ray.tune.result import TIMESTEPS_TOTAL, TRAINING_ITERATION
 from ray.tune.trial import Trial
 from ray.tune.utils.util import is_nan_or_inf
 
@@ -15,16 +11,41 @@
 
 @CALLBACKS.register_module()
 class MLflowLoggerCallback(_MLflowLoggerCallback):
+    """Custom MLflow Logger to automatically log Tune results and config to
+    MLflow. The main differences from the original MLflow Logger are:
+
+    1. Bind multiple runs into a parent run in the form of nested runs.
+    2. Log artifacts of the best trial to the parent run.
+
+    Refer to https://github.com/ray-project/ray/blob/ray-1.9.1/python/ray/tune/integration/mlflow.py for details.  # noqa E501
+
+    Args:
+        metric: Key for trial info to order on. Defaults to
+            ``self.default_metric``.
+        mode: One of [min, max]. Defaults to ``self.default_mode``.
+        scope: One of [all, last, avg, last-5-avg, last-10-avg].
+            If `scope=last`, only look at each trial's final step for
+            `metric`, and compare across trials based on `mode=[min,max]`.
+            If `scope=avg`, consider the simple average over all steps
+            for `metric` and compare across trials based on
+            `mode=[min,max]`. If `scope=last-5-avg` or `scope=last-10-avg`,
+            consider the simple average over the last 5 or 10 steps for
+            `metric` and compare across trials based on `mode=[min,max]`.
+            If `scope=all`, find each trial's min/max score for `metric`
+            based on `mode`, and compare trials based on `mode=[min,max]`.
+        filter_nan_and_inf (bool): If True, NaN or infinite values
+            are disregarded and these trials are never selected as
+            the best trial. Default: True.
+        **kwargs: kwargs for the original ``MLflowLoggerCallback``.
+    """
 
     def __init__(self,
-                 work_dir: Optional[str],
                  metric: str = None,
                  mode: str = None,
                  scope: str = 'last',
                  filter_nan_and_inf: bool = True,
                  **kwargs):
-        super().__init__(**kwargs)
-        self.work_dir = work_dir
+        super(MLflowLoggerCallback, self).__init__(**kwargs)
         self.metric = metric
         if mode and mode not in ['min', 'max']:
             raise ValueError('`mode` has to be None or one of [min, max]')
@@ -41,17 +62,15 @@ def __init__(self,
         self.filter_nan_and_inf = filter_nan_and_inf
 
     def setup(self, *args, **kwargs):
-        cp_trial_runs = getattr(self, '_trial_runs', dict()).copy()
         super().setup(*args, **kwargs)
-        self._trial_runs = cp_trial_runs
         self.parent_run = self.client.create_run(
             experiment_id=self.experiment_id, tags=self.tags)
 
     def log_trial_start(self, trial: 'Trial'):
        # Create run if not already exists.
         if trial not in self._trial_runs:
 
-            # Set trial name in tags.
+            # Set trial name in tags
             tags = self.tags.copy()
             tags['trial_name'] = str(trial)
             tags['mlflow.parentRunId'] = self.parent_run.info.run_id
@@ -66,37 +85,8 @@ def log_trial_start(self, trial: 'Trial'):
         config = trial.config
 
         for key, value in config.items():
-            key = re.sub(r'[^a-zA-Z0-9_=./\s]', '', key)
             self.client.log_param(run_id=run_id, key=key, value=value)
 
-    def log_trial_result(self, iteration: int, trial: 'Trial', result: Dict):
-        step = result.get(TIMESTEPS_TOTAL) or result[TRAINING_ITERATION]
-        run_id = self._trial_runs[trial]
-        for key, value in result.items():
-            key = re.sub(r'[^a-zA-Z0-9_=./\s]', '', key)
-            try:
-                value = float(value)
-            except (ValueError, TypeError):
-                logger.debug('Cannot log key {} with value {} since the '
-                             'value cannot be converted to float.'.format(
-                                 key, value))
-                continue
-
-            self.client.log_metric(
-                run_id=run_id, key=key, value=value, step=step)
-
-    def log_trial_end(self, trial: 'Trial', failed: bool = False):
-        run_id = self._trial_runs[trial]
-        trial_id = trial.trial_id
-        work_dir = osp.join(self.work_dir, trial_id)
-        config = glob.glob(osp.join(work_dir, '*.py'))
-        if config:
-            self.client.log_artifact(run_id, local_path=config.pop())
-
-        # Stop the run once trial finishes.
-        status = 'FINISHED' if not failed else 'FAILED'
-        self.client.set_terminated(run_id=run_id, status=status)
-
     def on_experiment_end(self, trials: List['Trial'], **info):
         if not self.metric or not self.mode:
             return
@@ -124,22 +114,19 @@ def on_experiment_end(self, trials: List['Trial'], **info):
         if best_trial not in self._trial_runs:
             return
 
+        # Copy the run of best trial to parent run.
         run_id = self._trial_runs[best_trial]
         run = self.client.get_run(run_id)
         parent_run_id = self.parent_run.info.run_id
+
         for key, value in run.data.params.items():
             self.client.log_param(run_id=parent_run_id, key=key, value=value)
+
         for key, value in run.data.metrics.items():
             self.client.log_metric(run_id=parent_run_id, key=key, value=value)
 
-        trial_id = best_trial.trial_id
-        work_dir = osp.join(self.work_dir, trial_id)
-        config = glob.glob(osp.join(work_dir, '*.py'))
-        if config:
-            self.client.log_artifact(parent_run_id, local_path=config.pop())
-
-        checkpoints = glob.glob(osp.join(work_dir, '*.pth'))
-        for checkpoint in checkpoints:
-            self.client.log_artifact(parent_run_id, local_path=checkpoint)
+        if self.save_artifact:
+            self.client.log_artifacts(
+                parent_run_id, local_dir=best_trial.logdir)
 
         self.client.set_terminated(run_id=parent_run_id, status='FINISHED')
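
For reference, a minimal usage sketch of the callback introduced in this diff. The import path, the `train_fn` trainable, and the `val_acc` metric name are placeholders rather than part of this commit; the remaining keyword arguments (`experiment_name`, `save_artifact`) are simply forwarded to Ray 1.9's base `MLflowLoggerCallback`.

# Hypothetical wiring into a Ray Tune run (Ray 1.9.x API); names below are
# illustrative only, and MLflowLoggerCallback is assumed to be the class
# defined above (the real import path depends on the surrounding project).
from ray import tune


def train_fn(config):
    # Toy trainable: report one metric per training iteration.
    for step in range(10):
        tune.report(val_acc=config['lr'] * step)


tune.run(
    train_fn,
    config={'lr': tune.grid_search([0.01, 0.1])},
    callbacks=[
        MLflowLoggerCallback(            # the class defined in this diff
            experiment_name='demo',      # forwarded to the Ray base callback
            save_artifact=True,          # enables the best-trial artifact copy
            metric='val_acc',            # key used to select the best trial
            mode='max',
            scope='last'),
    ])

With this setup, each Tune trial is logged as a child MLflow run nested under a single parent run, and when the experiment ends the best trial's params, metrics, and (if `save_artifact` is set) artifacts are copied to the parent run.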