This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit 6ecf78b

Refactored warmup, increased dataset size for MLP (#78)
1 parent 01563b7 commit 6ecf78b

2 files changed: +12 −24 lines

dl_bench/mlp.py (3 additions, 3 deletions)
@@ -81,8 +81,8 @@ def __init__(self, params) -> None:
 
         batch_size = int(params.get("batch_size", 1024))
 
-        min_batches = 10
-        DATASET_SIZE = max(10_240, batch_size * min_batches)
+        min_batches = 20
+        DATASET_SIZE = max(102_400, batch_size * min_batches)
         dataset = RandomInfDataset(DATASET_SIZE, in_shape)
 
         name = params.get("name", "size5")
@@ -92,5 +92,5 @@ def __init__(self, params) -> None:
 
         super().__init__(
             net=net, in_shape=in_shape, dataset=dataset, batch_size=batch_size,\
-            min_batches=min_batches, min_seconds=min_seconds
+            min_batches=min_batches, min_seconds=min_seconds, warmup_batches=10
         )
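For context, a rough sketch of what the new sizing works out to with the default batch size from the diff; the numbers below are illustrative only:

# Illustrative only: default batch size plus the constants introduced by this commit.
batch_size = 1024                                      # params.get("batch_size", 1024)
min_batches = 20                                       # raised from 10
DATASET_SIZE = max(102_400, batch_size * min_batches)
# max(102_400, 20_480) == 102_400 samples, i.e. 100 batches of 1024,
# up from the previous floor of 10_240 (10 batches).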

dl_bench/utils.py (9 additions, 21 deletions)
@@ -343,12 +343,13 @@ def _get_device(device_name):
 
 class Benchmark:
     def __init__(
-        self, net, in_shape, dataset, batch_size, min_batches=10, min_seconds=10
+        self, net, in_shape, dataset, batch_size, min_batches=10, min_seconds=10, warmup_batches=3,
     ) -> None:
         self.net = net
         self.in_shape = in_shape
         self.dataset = dataset
         self.batch_size = batch_size
+        self.warmup_batches = warmup_batches
         self.min_batches = min_batches
         self.min_seconds = min_seconds
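A hedged usage sketch of the widened constructor; `net`, `in_shape`, and `dataset` stand in for whatever the caller already builds, and only the keyword names come from this commit:

# Sketch only: the MLP benchmark above passes warmup_batches=10 explicitly;
# callers that omit it now get the default of 3 warmup batches.
bench = Benchmark(
    net=net,
    in_shape=in_shape,
    dataset=dataset,
    batch_size=1024,
    min_batches=20,
    min_seconds=10,
    warmup_batches=10,
)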

@@ -379,24 +380,6 @@ def inference(self, backend: Backend):
         sample = next(iter(test_loader))
         self.compile(sample, backend)
 
-        print("Warmup started")
-        with torch.no_grad(), tm.timeit("warmup_s"):
-            self.net.eval()
-            sample = backend.to_device(sample)
-            if backend.dtype != torch.float32:
-                with torch.autocast(
-                    device_type=backend.device_name,
-                    dtype=backend.dtype,
-                ):
-                    self.net(sample)
-                    self.net(sample)
-                    self.net(sample)
-            else:
-                self.net(sample)
-                self.net(sample)
-                self.net(sample)
-        print("Warmup done")
-
         n_items = 0
 
         self.net.eval()
@@ -417,15 +400,19 @@ def inference(self, backend: Backend):
                     y = self.net(x)
             else:
                 y = self.net(x)
-            if i < 3: continue
+
+            if i < self.warmup_batches:
+                start = time.perf_counter()
+                continue
+
             fw_times.append(get_time() - s)
             n_items += len(x)
             outputs.append(y)
 
             # early stopping if we have 10+ batches and were running for 10+ seconds
             if (
                 (time.perf_counter() - start) > self.min_seconds
-                and n_items > self.batch_size * self.min_batches
+                and n_items >= self.batch_size * self.min_batches
             ):
                 break
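Taken together with the removed block above, warmup now happens inside the timed loop: the first `self.warmup_batches` iterations still run the forward pass but only reset the wall-clock `start`, so they never land in `fw_times` or `n_items`. A condensed sketch of the resulting loop, assuming `test_loader`, `get_time`, and the autocast branch behave as elsewhere in the file:

# Condensed sketch, not the verbatim file; the autocast branch is folded into one call.
start = time.perf_counter()
for i, x in enumerate(test_loader):
    s = get_time()
    y = self.net(x)

    if i < self.warmup_batches:        # warmup: run the batch, record nothing
        start = time.perf_counter()    # keep pushing the wall clock forward
        continue

    fw_times.append(get_time() - s)
    n_items += len(x)
    outputs.append(y)

    # stop once both the time floor and the batch floor are met
    if (
        (time.perf_counter() - start) > self.min_seconds
        and n_items >= self.batch_size * self.min_batches
    ):
        break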

@@ -437,6 +424,7 @@ def inference(self, backend: Backend):
         )
 
         results = tm.get_results()
+        results["duration_s"] = get_time() - start
         results["samples_per_s"] = n_items / sum(fw_times)
         results["flops_per_sample"] = self.flops_per_sample
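Because `start` was last reset on the final warmup batch, the new `duration_s` field reports the elapsed time of the measured portion only (assuming `get_time` and `time.perf_counter` read comparable clocks, as the diff implies), while `samples_per_s` keeps dividing by the summed per-batch forward times. With made-up numbers:

# Hypothetical numbers, only to show what the reported fields mean.
n_items = 100 * 1024        # 100 timed batches of 1024 samples
fw_times_total = 8.0        # seconds spent inside the timed forward passes
duration_s = 9.5            # elapsed time of the measured portion, data loading included
samples_per_s = n_items / fw_times_total   # 12_800.0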
