Skip to content
This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit 55e0d4b

Browse files
author
Zhang Yan
committed
modify benchmark to make early stop work
1 parent 46fe63a commit 55e0d4b

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

dl_bench/mlp.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ def __init__(self, params) -> None:
8787

8888
name = params.get("name", "size5")
8989
net = get_mlp(n_chans_in=IN_FEAT, n_chans_out=N_CLASSES, name=name)
90+
min_batches = int(params.get("min_batches", 10))
91+
min_seconds = int(params.get("min_seconds", 10))
9092

9193
super().__init__(
92-
net=net, in_shape=in_shape, dataset=dataset, batch_size=batch_size
94+
net=net, in_shape=in_shape, dataset=dataset, batch_size=batch_size,\
95+
min_batches=min_batches, min_seconds=min_seconds
9396
)

dl_bench/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,8 @@ def str_to_dtype(dtype: str):
121121
return torch.float32
122122
elif dtype == "bfloat16":
123123
return torch.bfloat16
124+
elif dtype == "int8":
125+
return torch.qint8
124126
else:
125127
raise ValueError(f"Unsupported data type: {dtype}")
126128

@@ -404,9 +406,9 @@ def inference(self, backend: Backend):
404406
start = time.perf_counter()
405407
# Duration is inconsistent now
406408
with tm.timeit("duration_s"):
407-
for i, x in enumerate(test_loader):
409+
while True:
408410
s = get_time()
409-
x = backend.to_device(x)
411+
x = backend.to_device(sample)
410412
if backend.dtype != torch.float32:
411413
with torch.autocast(
412414
device_type=backend.device_name,

0 commit comments

Comments (0)