This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit a295084

Added more measurement info like p50, p90 (#87)

1 parent 2d28741

1 file changed: dl_bench/utils.py (8 additions, 28 deletions)
@@ -395,34 +395,14 @@ def inference(self, backend: Backend):
         self.net.eval()
         with torch.no_grad():
             start = time.perf_counter()
-            # Duration is inconsistent now
-            with tm.timeit("duration_s"):
-                for i, x in enumerate(test_loader):
-                    backend.sync()
-                    s = get_time()
-                    x = backend.to_device(x)
-                    if backend.dtype != torch.float32:
-                        with torch.autocast(
-                            device_type=backend.device_name,
-                            dtype=backend.dtype,
-                        ):
-                            y = self.net(x)
-                    else:
-                        y = self.net(x)
-
-                    if i < self.warmup_batches:
-                        start = time.perf_counter()
-                        continue
-
-                    backend.sync()
-                    fw_times.append(get_time() - s)
-                    n_items += len(x)
-                    outputs.append(y)
-
-                    # early stopping if we have 10+ batches and were running for 10+ seconds
-                    if (
-                        (time.perf_counter() - start) > self.min_seconds
-                        and n_items >= self.batch_size * self.min_batches
+            for i, x in enumerate(test_loader):
+                backend.sync()
+                s = get_time()
+                x = backend.to_device(x)
+                if backend.dtype != torch.float32:
+                    with torch.autocast(
+                        device_type=backend.device_name,
+                        dtype=backend.dtype,
                     ):
                         y = self.net(x)
                 else:
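The reporting side that the commit title refers to (p50/p90 statistics over the per-batch forward times accumulated in fw_times) lies outside this hunk. As a minimal sketch of how such percentiles are typically computed, assuming numpy is available; the helper name summarize_times and the exact set of statistics are illustrative, not taken from the commit:

import numpy as np

def summarize_times(fw_times):
    """Illustrative helper (not part of the commit): percentile
    statistics over per-batch forward times, in seconds."""
    times = np.asarray(fw_times, dtype=np.float64)
    return {
        "mean_s": float(times.mean()),
        "p50_s": float(np.percentile(times, 50)),  # median batch time
        "p90_s": float(np.percentile(times, 90)),  # tail latency
    }

# Example with three made-up batch timings:
print(summarize_times([0.010, 0.012, 0.030]))
# {'mean_s': 0.0173..., 'p50_s': 0.012, 'p90_s': 0.0264}

Syncing the backend on both sides of the forward pass, as the loop in the diff does, matters for these percentiles: on asynchronous devices a kernel launch returns immediately, so unsynchronized timestamps would understate per-batch latency.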
