Skip to content
This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit 700a97c

Browse files
Ivy ZhangEgor-Krivov
authored andcommitted
skip 3 warmup steps in benchmarking (#75)
1 parent 1e77319 commit 700a97c

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

dl_bench/mlp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ def __init__(self, params) -> None:
8787

8888
name = params.get("name", "size5")
8989
net = get_mlp(n_chans_in=IN_FEAT, n_chans_out=N_CLASSES, name=name)
90+
min_batches = int(params.get("min_batches", 10))
91+
min_seconds = int(params.get("min_seconds", 10))
9092

9193
super().__init__(
92-
net=net, in_shape=in_shape, dataset=dataset, batch_size=batch_size
94+
net=net, in_shape=in_shape, dataset=dataset, batch_size=batch_size,\
95+
min_batches=min_batches, min_seconds=min_seconds
9396
)

dl_bench/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def str_to_dtype(dtype: str):
121121
return torch.float32
122122
elif dtype == "bfloat16":
123123
return torch.bfloat16
124+
elif dtype == "int8":
125+
return torch.qint8
124126
else:
125127
raise ValueError(f"Unsupported data type: {dtype}")
126128

@@ -415,7 +417,7 @@ def inference(self, backend: Backend):
415417
y = self.net(x)
416418
else:
417419
y = self.net(x)
418-
420+
if i < 3: continue
419421
fw_times.append(get_time() - s)
420422
n_items += len(x)
421423
outputs.append(y)

0 commit comments

Comments
 (0)