@@ -343,12 +343,13 @@ def _get_device(device_name):
343343
344344class Benchmark :
345345 def __init__ (
346- self , net , in_shape , dataset , batch_size , min_batches = 10 , min_seconds = 10
346+ self , net , in_shape , dataset , batch_size , min_batches = 10 , min_seconds = 10 , warmup_batches = 3 ,
347347 ) -> None :
348348 self .net = net
349349 self .in_shape = in_shape
350350 self .dataset = dataset
351351 self .batch_size = batch_size
352+ self .warmup_batches = warmup_batches
352353 self .min_batches = min_batches
353354 self .min_seconds = min_seconds
354355
@@ -379,24 +380,6 @@ def inference(self, backend: Backend):
379380 sample = next (iter (test_loader ))
380381 self .compile (sample , backend )
381382
382- print ("Warmup started" )
383- with torch .no_grad (), tm .timeit ("warmup_s" ):
384- self .net .eval ()
385- sample = backend .to_device (sample )
386- if backend .dtype != torch .float32 :
387- with torch .autocast (
388- device_type = backend .device_name ,
389- dtype = backend .dtype ,
390- ):
391- self .net (sample )
392- self .net (sample )
393- self .net (sample )
394- else :
395- self .net (sample )
396- self .net (sample )
397- self .net (sample )
398- print ("Warmup done" )
399-
400383 n_items = 0
401384
402385 self .net .eval ()
@@ -417,15 +400,19 @@ def inference(self, backend: Backend):
417400 y = self .net (x )
418401 else :
419402 y = self .net (x )
420- if i < 3 : continue
403+
404+ if i < self .warmup_batches :
405+ start = time .perf_counter ()
406+ continue
407+
421408 fw_times .append (get_time () - s )
422409 n_items += len (x )
423410 outputs .append (y )
424411
425412 # early stopping once we've run for at least min_seconds and processed at least min_batches batches
426413 if (
427414 (time .perf_counter () - start ) > self .min_seconds
428- and n_items > self .batch_size * self .min_batches
415+ and n_items >= self .batch_size * self .min_batches
429416 ):
430417 break
431418
@@ -437,6 +424,7 @@ def inference(self, backend: Backend):
437424 )
438425
439426 results = tm .get_results ()
427+ results ["duration_s" ] = get_time () - start
440428 results ["samples_per_s" ] = n_items / sum (fw_times )
441429 results ["flops_per_sample" ] = self .flops_per_sample
442430
0 commit comments