Skip to content
This repository was archived by the owner on Jul 24, 2024. It is now read-only.

Commit 5dee994

Browse files
authored
Enabled verification with inconsistent number of batches (#51)
1 parent 111181b commit 5dee994

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

dl_bench/cli/launcher.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,19 @@
1717
}
1818

1919

20+
def fix_lengths(outputs, ref_outputs):
21+
"""To speed up benchmarking we pass different number of batches for different backends.
22+
Need to match the lenghts."""
23+
min_lengths = min(len(outputs), len(ref_outputs))
24+
if len(outputs) != len(ref_outputs):
25+
print(
26+
f"Slicing passed batches to smallest size {len(outputs)}->{min_lengths}; {len(ref_outputs)}->{min_lengths}"
27+
)
28+
return outputs[:min_lengths], ref_outputs[:min_lengths]
29+
else:
30+
return outputs, ref_outputs
31+
32+
2033
def parse_args():
2134
parser = argparse.ArgumentParser()
2235
# Benchmark
@@ -142,6 +155,7 @@ def main():
142155
reference_backend = Backend(device=ref_device, compiler="torch", dtype=dtype)
143156
_, ref_outputs = benchmark.inference(reference_backend)
144157
results, outputs = benchmark.inference(backend)
158+
outputs, ref_outputs = fix_lengths(outputs, ref_outputs)
145159
cmp_res = compare(outputs, ref_outputs)
146160

147161
print(f"Benchmark {benchmark_name} completed")

0 commit comments

Comments
 (0)