@@ -42,15 +42,15 @@ class EMAConfig(BaseModel):
4242 def decay_check (cls , v ):
4343 if v <= 0 or v >= 1 :
4444 raise ValueError (
45- f"'decay' should be in (0, 1) when is type of float, but got { v } "
45+ f"'ema. decay' should be in (0, 1) when is type of float, but got { v } "
4646 )
4747 return v
4848
4949 @field_validator ("avg_freq" )
5050 def avg_freq_check (cls , v ):
5151 if v <= 0 :
5252 raise ValueError (
53- "'avg_freq' should be a positive integer when is type of int, "
53+ "'ema. avg_freq' should be a positive integer when is type of int, "
5454 f"but got { v } "
5555 )
5656 return v
@@ -63,15 +63,17 @@ class SWAConfig(BaseModel):
6363 @field_validator ("avg_range" )
6464 def avg_range_check (cls , v , info : ValidationInfo ):
6565 if isinstance (v , tuple ) and v [0 ] > v [1 ]:
66- raise ValueError (f"'avg_range' should be a valid range, but got { v } ." )
66+ raise ValueError (
67+ f"'swa.avg_range' should be a valid range, but got { v } ."
68+ )
6769 if isinstance (v , tuple ) and v [0 ] < 0 :
6870 raise ValueError (
69- "The start epoch of 'avg_range' should be a non-negtive integer"
71+ "The start epoch of 'swa. avg_range' should be a non-negtive integer"
7072 f" , but got { v [0 ]} ."
7173 )
7274 if isinstance (v , tuple ) and v [1 ] > info .data ["epochs" ]:
7375 raise ValueError (
74- "The end epoch of 'avg_range' should not be lager than "
76+ "The end epoch of 'swa. avg_range' should not be lager than "
7577 f"'epochs'({ info .data ['epochs' ]} ), but got { v [1 ]} ."
7678 )
7779 return v
@@ -80,7 +82,7 @@ def avg_range_check(cls, v, info: ValidationInfo):
8082 def avg_freq_check (cls , v ):
8183 if v <= 0 :
8284 raise ValueError (
83- "'avg_freq' should be a positive integer when is type of int, "
85+ "'swa. avg_freq' should be a positive integer when is type of int, "
8486 f"but got { v } "
8587 )
8688 return v
@@ -107,7 +109,7 @@ class TrainConfig(BaseModel):
107109 def epochs_check (cls , v ):
108110 if v <= 0 :
109111 raise ValueError (
110- "'epochs' should be a positive integer when is type of int, "
112+ "'TRAIN. epochs' should be a positive integer when is type of int, "
111113 f"but got { v } "
112114 )
113115 return v
@@ -116,7 +118,7 @@ def epochs_check(cls, v):
116118 def iters_per_epoch_check (cls , v ):
117119 if v <= 0 :
118120 raise ValueError (
119- "'iters_per_epoch' should be a positive integer when is type of int"
121+ "'TRAIN. iters_per_epoch' should be a positive integer when is type of int"
120122 f", but got { v } "
121123 )
122124 return v
@@ -125,7 +127,7 @@ def iters_per_epoch_check(cls, v):
125127 def update_freq_check (cls , v ):
126128 if v <= 0 :
127129 raise ValueError (
128- "'update_freq' should be a positive integer when is type of int"
130+ "'TRAIN. update_freq' should be a positive integer when is type of int"
129131 f", but got { v } "
130132 )
131133 return v
@@ -134,7 +136,7 @@ def update_freq_check(cls, v):
134136 def save_freq_check (cls , v ):
135137 if v < 0 :
136138 raise ValueError (
137- "'save_freq' should be a non-negtive integer when is type of int"
139+ "'TRAIN. save_freq' should be a non-negtive integer when is type of int"
138140 f", but got { v } "
139141 )
140142 return v
@@ -144,8 +146,8 @@ def start_eval_epoch_check(cls, v, info: ValidationInfo):
144146 if info .data ["eval_during_train" ]:
145147 if v <= 0 :
146148 raise ValueError (
147- f"'start_eval_epoch' should be a positive integer when "
148- f"'eval_during_train' is True, but got { v } "
149+ f"'TRAIN. start_eval_epoch' should be a positive integer when "
150+ f"'TRAIN. eval_during_train' is True, but got { v } "
149151 )
150152 return v
151153
@@ -154,8 +156,8 @@ def eval_freq_check(cls, v, info: ValidationInfo):
154156 if info .data ["eval_during_train" ]:
155157 if v <= 0 :
156158 raise ValueError (
157- f"'eval_freq' should be a positive integer when "
158- f"'eval_during_train' is True, but got { v } "
159+ f"'TRAIN. eval_freq' should be a positive integer when "
160+ f"'TRAIN. eval_during_train' is True, but got { v } "
159161 )
160162 return v
161163
@@ -176,6 +178,15 @@ class EvalConfig(BaseModel):
176178 pretrained_model_path : Optional [str ] = None
177179 eval_with_no_grad : bool = False
178180 compute_metric_by_batch : bool = False
181+ batch_size : Optional [int ] = 256
182+
183+ @field_validator ("batch_size" )
184+ def batch_size_check (cls , v ):
185+ if isinstance (v , int ) and v <= 0 :
186+ raise ValueError (
187+ f"'EVAL.batch_size' should be greater than 0 or None, but got { v } "
188+ )
189+ return v
179190
180191 class InferConfig (BaseModel ):
181192 """
@@ -203,12 +214,12 @@ class InferConfig(BaseModel):
203214 def engine_check (cls , v , info : ValidationInfo ):
204215 if v == "tensorrt" and info .data ["device" ] != "gpu" :
205216 raise ValueError (
206- "'device' should be 'gpu' when 'engine' is 'tensorrt', "
217+ "'INFER. device' should be 'gpu' when 'INFER. engine' is 'tensorrt', "
207218 f"but got '{ info .data ['device' ]} '"
208219 )
209220 if v == "mkldnn" and info .data ["device" ] != "cpu" :
210221 raise ValueError (
211- "'device' should be 'cpu' when 'engine' is 'mkldnn', "
222+ "'INFER. device' should be 'cpu' when 'INFER. engine' is 'mkldnn', "
212223 f"but got '{ info .data ['device' ]} '"
213224 )
214225
@@ -218,46 +229,50 @@ def engine_check(cls, v, info: ValidationInfo):
218229 def min_subgraph_size_check (cls , v ):
219230 if v <= 0 :
220231 raise ValueError (
221- "'min_subgraph_size' should be greater than 0, " f"but got { v } "
232+ "'INFER.min_subgraph_size' should be greater than 0, "
233+ f"but got { v } "
222234 )
223235 return v
224236
225237 @field_validator ("gpu_mem" )
226238 def gpu_mem_check (cls , v ):
227239 if v <= 0 :
228- raise ValueError ("'gpu_mem' should be greater than 0, " f"but got { v } " )
240+ raise ValueError (
241+ "'INFER.gpu_mem' should be greater than 0, " f"but got { v } "
242+ )
229243 return v
230244
231245 @field_validator ("gpu_id" )
232246 def gpu_id_check (cls , v ):
233247 if v < 0 :
234248 raise ValueError (
235- "'gpu_id' should be greater than or equal to 0, " f"but got { v } "
249+ "'INFER.gpu_id' should be greater than or equal to 0, "
250+ f"but got { v } "
236251 )
237252 return v
238253
239254 @field_validator ("max_batch_size" )
240255 def max_batch_size_check (cls , v ):
241256 if v <= 0 :
242257 raise ValueError (
243- "'max_batch_size' should be greater than 0, " f"but got { v } "
258+ "'INFER. max_batch_size' should be greater than 0, " f"but got { v } "
244259 )
245260 return v
246261
247262 @field_validator ("num_cpu_threads" )
248263 def num_cpu_threads_check (cls , v ):
249264 if v < 0 :
250265 raise ValueError (
251- "'num_cpu_threads' should be greater than or equal to 0, "
266+ "'INFER. num_cpu_threads' should be greater than or equal to 0, "
252267 f"but got { v } "
253268 )
254269 return v
255270
256271 @field_validator ("batch_size" )
257272 def batch_size_check (cls , v ):
258- if v <= 0 :
273+ if isinstance ( v , int ) and v <= 0 :
259274 raise ValueError (
260- "' batch_size' should be greater than 0, " f" but got { v } "
275+ f"'INFER. batch_size' should be greater than 0 or None, but got { v } "
261276 )
262277 return v
263278
@@ -326,7 +341,8 @@ def use_wandb_check(cls, v, info: ValidationInfo):
326341 - TRAIN/swa: swa_default <-- 'swa_default' used here
327342 - EVAL: eval_default <-- 'eval_default' used here
328343 - INFER: infer_default <-- 'infer_default' used here
329- - _self_
344+ - _self_ <-- config defined in current yaml
345+
330346 mode: train
331347 seed: 42
332348 ...
@@ -384,6 +400,7 @@ def use_wandb_check(cls, v, info: ValidationInfo):
384400 "EVAL.pretrained_model_path" ,
385401 "EVAL.eval_with_no_grad" ,
386402 "EVAL.compute_metric_by_batch" ,
403+ "EVAL.batch_size" ,
387404 "INFER.pretrained_model_path" ,
388405 "INFER.export_path" ,
389406 "INFER.pdmodel_path" ,
0 commit comments