This repository was archived by the owner on Jul 24, 2024. It is now read-only.
Merged
6 changes: 3 additions & 3 deletions dl_bench/mlp.py
@@ -81,8 +81,8 @@ def __init__(self, params) -> None:
 
         batch_size = int(params.get("batch_size", 1024))
 
-        min_batches = 10
-        DATASET_SIZE = max(10_240, batch_size * min_batches)
+        min_batches = 20
+        DATASET_SIZE = max(102_400, batch_size * min_batches)
         dataset = RandomInfDataset(DATASET_SIZE, in_shape)
 
         name = params.get("name", "size5")
@@ -92,5 +92,5 @@ def __init__(self, params) -> None:
 
         super().__init__(
             net=net, in_shape=in_shape, dataset=dataset, batch_size=batch_size,\
-            min_batches=min_batches, min_seconds=min_seconds
+            min_batches=min_batches, min_seconds=min_seconds, warmup_batches=10
        )
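
For context, the new sizing defaults make the dataset floor dominate at typical batch sizes. A quick check with the values from the hunk above (the default batch_size of 1024 comes from the params lookup shown there):

    batch_size = 1024    # default from params.get("batch_size", 1024)
    min_batches = 20
    DATASET_SIZE = max(102_400, batch_size * min_batches)  # max(102_400, 20_480) == 102_400

The 102_400-sample floor wins until batch_size exceeds 5_120 (102_400 / 20).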
30 changes: 9 additions & 21 deletions dl_bench/utils.py
@@ -343,12 +343,13 @@ def _get_device(device_name):
 
 class Benchmark:
     def __init__(
-        self, net, in_shape, dataset, batch_size, min_batches=10, min_seconds=10
+        self, net, in_shape, dataset, batch_size, min_batches=10, min_seconds=10, warmup_batches=3,
     ) -> None:
         self.net = net
         self.in_shape = in_shape
         self.dataset = dataset
         self.batch_size = batch_size
+        self.warmup_batches = warmup_batches
         self.min_batches = min_batches
         self.min_seconds = min_seconds
 
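The new keyword defaults to 3, so existing Benchmark subclasses keep a short warmup unchanged, while individual benchmarks can ask for more; the MLP benchmark above passes warmup_batches=10. A hypothetical construction for illustration (argument values invented here):

    bench = Benchmark(
        net=net, in_shape=in_shape, dataset=dataset,
        batch_size=1024, min_batches=20, min_seconds=10,
        warmup_batches=10,
    )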
@@ -379,24 +380,6 @@ def inference(self, backend: Backend):
         sample = next(iter(test_loader))
         self.compile(sample, backend)
 
-        print("Warmup started")
-        with torch.no_grad(), tm.timeit("warmup_s"):
-            self.net.eval()
-            sample = backend.to_device(sample)
-            if backend.dtype != torch.float32:
-                with torch.autocast(
-                    device_type=backend.device_name,
-                    dtype=backend.dtype,
-                ):
-                    self.net(sample)
-                    self.net(sample)
-                    self.net(sample)
-            else:
-                self.net(sample)
-                self.net(sample)
-                self.net(sample)
-        print("Warmup done")
-
         n_items = 0
 
         self.net.eval()
@@ -417,15 +400,19 @@ def inference(self, backend: Backend):
                         y = self.net(x)
                 else:
                     y = self.net(x)
-                if i < 3: continue
+
+                if i < self.warmup_batches:
+                    start = time.perf_counter()
+                    continue
 
                 fw_times.append(get_time() - s)
                 n_items += len(x)
                 outputs.append(y)
 
                 # early stopping if we have 10+ batches and were running for 10+ seconds
                 if (
                     (time.perf_counter() - start) > self.min_seconds
-                    and n_items > self.batch_size * self.min_batches
+                    and n_items >= self.batch_size * self.min_batches
                 ):
                     break

@@ -437,6 +424,7 @@ def inference(self, backend: Backend):
         )
 
         results = tm.get_results()
+        results["duration_s"] = get_time() - start
         results["samples_per_s"] = n_items / sum(fw_times)
         results["flops_per_sample"] = self.flops_per_sample

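Taken together, the utils.py changes fold warmup into the measured loop: the first warmup_batches iterations run the network but are discarded, the wall clock is restarted on every skipped batch so min_seconds counts measured time only, and duration_s is reported alongside throughput. A minimal self-contained sketch of that pattern (function and variable names here are illustrative, not the repo's API; device transfer, autocast, and output collection are omitted, and the loader is assumed to yield more than warmup_batches batches):

    import time

    def timed_inference(net, loader, batch_size, warmup_batches=3,
                        min_batches=10, min_seconds=10):
        fw_times, n_items = [], 0
        start = time.perf_counter()
        for i, x in enumerate(loader):
            s = time.perf_counter()
            y = net(x)                       # forward pass, timed per batch
            if i < warmup_batches:
                start = time.perf_counter()  # reset wall clock; warmup is excluded
                continue
            fw_times.append(time.perf_counter() - s)
            n_items += len(x)
            # stop once both floors are met: enough seconds and enough batches
            if (time.perf_counter() - start) > min_seconds \
                    and n_items >= batch_size * min_batches:
                break
        return {
            "duration_s": time.perf_counter() - start,
            "samples_per_s": n_items / sum(fw_times),
        }

Note the >= fix in the early-stopping test: with the old strict >, a run that measured exactly min_batches full batches never satisfied n_items > batch_size * min_batches and could only stop on the time floor.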