@@ -30,59 +30,59 @@
 
 
 def update_train_loss(
-    trainer: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
+    solver: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
 ):
     for key in loss_dict:
-        if key not in trainer.train_output_info:
-            trainer.train_output_info[key] = misc.AverageMeter(key, "7.5f")
-        trainer.train_output_info[key].update(float(loss_dict[key]), batch_size)
-        if key not in trainer.train_loss_info:
-            trainer.train_loss_info[key] = misc.AverageMeter(key, ".5f")
-        trainer.train_loss_info[key].update(float(loss_dict[key]))
+        if key not in solver.train_output_info:
+            solver.train_output_info[key] = misc.AverageMeter(key, "7.5f")
+        solver.train_output_info[key].update(float(loss_dict[key]), batch_size)
+        if key not in solver.train_loss_info:
+            solver.train_loss_info[key] = misc.AverageMeter(key, ".5f")
+        solver.train_loss_info[key].update(float(loss_dict[key]))
 
 
 def update_eval_loss(
-    trainer: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
+    solver: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
 ):
     for key in loss_dict:
-        if key not in trainer.eval_output_info:
-            trainer.eval_output_info[key] = misc.AverageMeter(key, "7.5f")
-        trainer.eval_output_info[key].update(float(loss_dict[key]), batch_size)
+        if key not in solver.eval_output_info:
+            solver.eval_output_info[key] = misc.AverageMeter(key, "7.5f")
+        solver.eval_output_info[key].update(float(loss_dict[key]), batch_size)
 
 
 def log_train_info(
-    trainer: "solver.Solver", batch_size: int, epoch_id: int, iter_id: int
+    solver: "solver.Solver", batch_size: int, epoch_id: int, iter_id: int
 ):
-    lr_msg = f"lr: {trainer.optimizer.get_lr():.5f}"
+    lr_msg = f"lr: {solver.optimizer.get_lr():.5f}"
 
     metric_msg = ", ".join(
         [
-            f"{key}: {trainer.train_output_info[key].avg:.5f}"
-            for key in trainer.train_output_info
+            f"{key}: {solver.train_output_info[key].avg:.5f}"
+            for key in solver.train_output_info
         ]
     )
 
     time_msg = ", ".join(
-        [trainer.train_time_info[key].mean for key in trainer.train_time_info]
+        [solver.train_time_info[key].mean for key in solver.train_time_info]
     )
 
-    ips_msg = f"ips: {batch_size / trainer.train_time_info['batch_cost'].avg:.2f}"
-    if trainer.benchmark_flag:
+    ips_msg = f"ips: {batch_size / solver.train_time_info['batch_cost'].avg:.2f}"
+    if solver.benchmark_flag:
         ips_msg += " samples/s"
 
     eta_sec = (
-        (trainer.epochs - epoch_id + 1) * trainer.iters_per_epoch - iter_id
-    ) * trainer.train_time_info["batch_cost"].avg
+        (solver.epochs - epoch_id + 1) * solver.iters_per_epoch - iter_id
+    ) * solver.train_time_info["batch_cost"].avg
     eta_msg = f"eta: {str(datetime.timedelta(seconds=int(eta_sec)))}"
 
-    epoch_width = len(str(trainer.epochs))
-    iters_width = len(str(trainer.iters_per_epoch))
+    epoch_width = len(str(solver.epochs))
+    iters_width = len(str(solver.iters_per_epoch))
     log_str = (
-        f"[Train][Epoch {epoch_id:>{epoch_width}}/{trainer.epochs}]"
-        f"[Iter {iter_id:>{iters_width}}/{trainer.iters_per_epoch}] {lr_msg}, "
+        f"[Train][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
+        f"[Iter {iter_id:>{iters_width}}/{solver.iters_per_epoch}] {lr_msg}, "
         f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
     )
-    if trainer.benchmark_flag:
+    if solver.benchmark_flag:
         max_mem_reserved_msg = (
             f"max_mem_reserved: {device.cuda.max_memory_reserved() // (1 << 20)} MB"
         )
@@ -94,57 +94,57 @@ def log_train_info(
 
     logger.scalar(
         {
-            "train/lr": trainer.optimizer.get_lr(),
+            "train/lr": solver.optimizer.get_lr(),
             **{
-                f"train/{key}": trainer.train_output_info[key].avg
-                for key in trainer.train_output_info
+                f"train/{key}": solver.train_output_info[key].avg
+                for key in solver.train_output_info
             },
         },
-        step=trainer.global_step,
-        vdl_writer=trainer.vdl_writer,
-        wandb_writer=trainer.wandb_writer,
-        tbd_writer=trainer.tbd_writer,
+        step=solver.global_step,
+        vdl_writer=solver.vdl_writer,
+        wandb_writer=solver.wandb_writer,
+        tbd_writer=solver.tbd_writer,
     )
 
 
 def log_eval_info(
-    trainer: "solver.Solver",
+    solver: "solver.Solver",
     batch_size: int,
     epoch_id: int,
     iters_per_epoch: int,
     iter_id: int,
 ):
     metric_msg = ", ".join(
         [
-            f"{key}: {trainer.eval_output_info[key].avg:.5f}"
-            for key in trainer.eval_output_info
+            f"{key}: {solver.eval_output_info[key].avg:.5f}"
+            for key in solver.eval_output_info
         ]
     )
 
     time_msg = ", ".join(
-        [trainer.eval_time_info[key].mean for key in trainer.eval_time_info]
+        [solver.eval_time_info[key].mean for key in solver.eval_time_info]
     )
 
-    ips_msg = f"ips: {batch_size / trainer.eval_time_info['batch_cost'].avg:.2f}"
+    ips_msg = f"ips: {batch_size / solver.eval_time_info['batch_cost'].avg:.2f}"
 
-    eta_sec = (iters_per_epoch - iter_id) * trainer.eval_time_info["batch_cost"].avg
+    eta_sec = (iters_per_epoch - iter_id) * solver.eval_time_info["batch_cost"].avg
     eta_msg = f"eta: {str(datetime.timedelta(seconds=int(eta_sec)))}"
 
-    epoch_width = len(str(trainer.epochs))
+    epoch_width = len(str(solver.epochs))
     iters_width = len(str(iters_per_epoch))
     logger.info(
-        f"[Eval][Epoch {epoch_id:>{epoch_width}}/{trainer.epochs}]"
+        f"[Eval][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
        f"[Iter {iter_id:>{iters_width}}/{iters_per_epoch}] "
         f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
     )
 
     logger.scalar(
         {
-            f"eval/{key}": trainer.eval_output_info[key].avg
-            for key in trainer.eval_output_info
+            f"eval/{key}": solver.eval_output_info[key].avg
+            for key in solver.eval_output_info
         },
-        step=trainer.global_step,
-        vdl_writer=trainer.vdl_writer,
-        wandb_writer=trainer.wandb_writer,
-        tbd_writer=trainer.tbd_writer,
+        step=solver.global_step,
+        vdl_writer=solver.vdl_writer,
+        wandb_writer=solver.wandb_writer,
+        tbd_writer=solver.tbd_writer,
    )
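
Note: these helpers assume `misc.AverageMeter` tracks a batch-weighted running average (`update(value, n)`, numeric `.avg`) and can render itself as a string (`.mean`, joined into `time_msg`). Below is a minimal sketch of that interface, inferred purely from the call sites in this diff; the names and defaults are assumptions, not the repository's actual implementation, which lives in its `misc` module.

    class AverageMeter:
        """Minimal running-average meter matching the usage above (sketch)."""

        def __init__(self, name: str, fmt: str = "f"):
            self.name = name
            self.fmt = fmt  # format spec such as "7.5f" or ".5f", as passed above
            self.reset()

        def reset(self):
            self.val = 0.0
            self.sum = 0.0
            self.count = 0

        def update(self, val: float, n: int = 1):
            # Batch-weighted accumulation, matching
            # update_train_loss(..., batch_size) above.
            self.val = val
            self.sum += val * n
            self.count += n

        @property
        def avg(self) -> float:
            return self.sum / max(self.count, 1)

        @property
        def mean(self) -> str:
            # Rendered "name: value" string, as consumed by the time_msg joins.
            return f"{self.name}: {self.avg:{self.fmt}}"

For example, `m = AverageMeter("loss", "7.5f"); m.update(0.5, n=16)` yields `m.avg == 0.5` and `m.mean == "loss: 0.50000"`, which is the shape of value the logging functions format into `metric_msg` and `time_msg`.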