1 file changed: 28 additions (+), 8 deletions (-)

@@ -395,14 +395,34 @@ def inference(self, backend: Backend):
         self.net.eval()
         with torch.no_grad():
             start = time.perf_counter()
-            for i, x in enumerate(test_loader):
-                backend.sync()
-                s = get_time()
-                x = backend.to_device(x)
-                if backend.dtype != torch.float32:
-                    with torch.autocast(
-                        device_type=backend.device_name,
-                        dtype=backend.dtype,
+            # Duration is inconsistent now
+            with tm.timeit("duration_s"):
+                for i, x in enumerate(test_loader):
+                    backend.sync()
+                    s = get_time()
+                    x = backend.to_device(x)
+                    if backend.dtype != torch.float32:
+                        with torch.autocast(
+                            device_type=backend.device_name,
+                            dtype=backend.dtype,
+                        ):
+                            y = self.net(x)
+                    else:
+                        y = self.net(x)
+
+                    if i < self.warmup_batches:
+                        start = time.perf_counter()
+                        continue
+
+                    backend.sync()
+                    fw_times.append(get_time() - s)
+                    n_items += len(x)
+                    outputs.append(y)
+
+                    # early stopping if we have 10+ batches and were running for 10+ seconds
+                    if (
+                        (time.perf_counter() - start) > self.min_seconds
+                        and n_items >= self.batch_size * self.min_batches
                     ):
                         y = self.net(x)
                 else:
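For readers skimming the change: the new loop excludes the first `warmup_batches` iterations from the timing statistics, synchronizes the backend around each per-batch measurement, and stops early once both a minimum wall-clock duration and a minimum number of batches have been reached. Below is a minimal, self-contained sketch of that pattern under simplifying assumptions: the `timeit` context manager here is a hypothetical stand-in for the repo's `tm.timeit`, the CPU/float32 setup drops `backend.sync()`, `get_time()`, and autocast, and the `break` in the early-stopping branch is an assumption, since the truncated diff cuts off before that line.

import time
from contextlib import contextmanager

import torch
import torch.nn as nn


@contextmanager
def timeit(label: str, results: dict):
    # Hypothetical stand-in for the repo's tm.timeit: records the
    # wall-clock duration of the wrapped block under `label`.
    t0 = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - t0


def benchmark_inference(net, test_loader, warmup_batches=2,
                        min_seconds=10.0, min_batches=10, batch_size=32):
    # Sketch of the loop this diff introduces, for a plain CPU/float32
    # setup (no backend.sync()/autocast); the `break` below is an
    # assumption, as the truncated diff cuts off at that point.
    fw_times, outputs, results = [], [], {}
    n_items = 0
    net.eval()
    with torch.no_grad():
        start = time.perf_counter()
        with timeit("duration_s", results):
            for i, x in enumerate(test_loader):
                s = time.perf_counter()
                y = net(x)

                # Warmup batches are excluded from the per-batch timings;
                # resetting `start` also restarts the early-stopping clock.
                if i < warmup_batches:
                    start = time.perf_counter()
                    continue

                fw_times.append(time.perf_counter() - s)
                n_items += len(x)
                outputs.append(y)

                # Stop once both thresholds are satisfied.
                if ((time.perf_counter() - start) > min_seconds
                        and n_items >= batch_size * min_batches):
                    break
    return fw_times, outputs, results


if __name__ == "__main__":
    net = nn.Linear(16, 4)
    loader = [torch.randn(32, 16) for _ in range(100)]
    times, outs, res = benchmark_inference(net, loader, min_seconds=0.01)
    print(len(times), res["duration_s"])

One design point worth noting: because `start` is reset on every warmup batch, the `min_seconds` clock only begins once warmup is over, so the early-stopping check and the recorded timings cover the same post-warmup window.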