This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit 2d28741

Added sync for nvidia backend (#84)
1 parent: 6d5e68a
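
The point of the change: CUDA kernel launches are asynchronous, so a host-side timer that does not synchronize with the device measures only how long it takes to enqueue work, not how long it takes to execute. The Backend class itself is outside this hunk; below is a minimal hypothetical sketch, not the repository's actual class, of what a sync-capable nvidia backend could look like (the names sync, to_device, device_name, and dtype are assumed from the calls in the diff below):

    import torch

    class Backend:
        """Hypothetical sketch, not the repository's actual Backend class."""

        def __init__(self, device_name: str = "cuda", dtype: torch.dtype = torch.float32):
            self.device_name = device_name
            self.dtype = dtype
            self.device = torch.device(device_name)

        def to_device(self, x: torch.Tensor) -> torch.Tensor:
            # Move a batch to the target device.
            return x.to(self.device)

        def sync(self) -> None:
            # Block the host until all queued CUDA work has finished.
            # Without this, time.perf_counter() would only measure how long
            # it takes to *launch* the kernels, not to run them.
            if self.device.type == "cuda":
                torch.cuda.synchronize(self.device)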

File tree: 1 file changed, +28 -8 lines

dl_bench/utils.py

@@ -395,14 +395,34 @@ def inference(self, backend: Backend):
         self.net.eval()
         with torch.no_grad():
             start = time.perf_counter()
-            for i, x in enumerate(test_loader):
-                backend.sync()
-                s = get_time()
-                x = backend.to_device(x)
-                if backend.dtype != torch.float32:
-                    with torch.autocast(
-                        device_type=backend.device_name,
-                        dtype=backend.dtype,
+            # Duration is inconsistent now
+            with tm.timeit("duration_s"):
+                for i, x in enumerate(test_loader):
+                    backend.sync()
+                    s = get_time()
+                    x = backend.to_device(x)
+                    if backend.dtype != torch.float32:
+                        with torch.autocast(
+                            device_type=backend.device_name,
+                            dtype=backend.dtype,
+                        ):
+                            y = self.net(x)
+                    else:
+                        y = self.net(x)
+
+                    if i < self.warmup_batches:
+                        start = time.perf_counter()
+                        continue
+
+                    backend.sync()
+                    fw_times.append(get_time() - s)
+                    n_items += len(x)
+                    outputs.append(y)
+
+                    # early stopping if we have 10+ batches and were running for 10+ seconds
+                    if (
+                        (time.perf_counter() - start) > self.min_seconds
+                        and n_items >= self.batch_size * self.min_batches
                 ):
                     y = self.net(x)
                 else:
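
The timing pattern above brackets each measured forward pass with two backend.sync() calls: one before reading the start timestamp, so no previously queued work is still in flight, and one before reading the stop timestamp, so the interval covers actual device execution rather than kernel launch. The first self.warmup_batches iterations reset start and are excluded from fw_times, and the loop can stop early once the run has lasted at least self.min_seconds and processed at least self.batch_size * self.min_batches items. The tm object is not defined in this hunk; assuming it is a small timer registry, a compatible timeit context manager could be sketched as:

    import time
    from contextlib import contextmanager

    class TimerManager:
        """Hypothetical stand-in for the `tm` helper used in the diff."""

        def __init__(self):
            self.results = {}

        @contextmanager
        def timeit(self, name: str):
            start = time.perf_counter()
            try:
                yield
            finally:
                # Record the wall-clock duration of the block under `name`.
                self.results[name] = time.perf_counter() - start

    tm = TimerManager()
    with tm.timeit("duration_s"):
        time.sleep(0.1)  # stands in for the whole inference loop
    print(tm.results["duration_s"])  # ~0.1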
