
Commit bb3252b

Use better cost model for compute (#146)
These values were obtained empirically by minimizing runtime over a set of examples on H100 GPUs. I've rounded them for simplicity.
1 parent 1232662 commit bb3252b

File tree

1 file changed (+15, -6)


autoparallel/compute_estimation.py

Lines changed: 15 additions & 6 deletions
```diff
@@ -300,18 +300,27 @@ def estimate_strategy_runtime_cost(node, strategy):
     gpu_memory_bandwidth = _get_device_gmem_bandwidth()
     read_write_time = read_write_bytes / gpu_memory_bandwidth * 1e6  # us

+    # suppose 70% efficiency for the operator
+    read_write_efficiency = 0.70
+
+    kernel_launch_overhead = 7  # us
+
+    read_write_time = max(
+        read_write_time / read_write_efficiency, kernel_launch_overhead
+    )
+
+    if flops == 0:
+        return read_write_time
     # TODO: fix this
     dtype = strategy.input_specs[0].tensor_meta.dtype

-    # TODO: better handle this case
-    if dtype.is_complex:
-        return read_write_time
     # TODO: use PyTorch's version once it's giving correct results
     gpu_flops = _get_device_tflops(dtype) * 10**12

-    # suppose 50% efficiency for the operator
-    factor = 1 / 0.5
-    compute_time = factor * flops / gpu_flops * 1e6  # us
+    # suppose 70% efficiency for the operator
+    compute_efficiency = 0.70
+    compute_time = flops / gpu_flops * 1e6  # us
+    compute_time = max(compute_time / compute_efficiency, kernel_launch_overhead)

     return max(compute_time, read_write_time)
```
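The patch amounts to a roofline-style estimate: a kernel's runtime is the slower of its memory time and its compute time, each derated by an efficiency factor and floored at a fixed kernel-launch overhead. A minimal self-contained sketch of that model, assuming illustrative H100-ish peak numbers in place of the repo's `_get_device_gmem_bandwidth` / `_get_device_tflops` helpers:

```python
def estimate_runtime_us(flops, read_write_bytes,
                        gmem_bandwidth=3.35e12,  # bytes/s; rough H100 HBM3 peak (assumed)
                        peak_tflops=989.0):      # dense bf16 TFLOP/s; rough H100 peak (assumed)
    kernel_launch_overhead = 7   # us, floor for any kernel
    read_write_efficiency = 0.70  # fraction of peak bandwidth actually achieved
    compute_efficiency = 0.70     # fraction of peak FLOP/s actually achieved

    # Memory side: bytes moved at a fraction of peak bandwidth.
    read_write_time = read_write_bytes / gmem_bandwidth * 1e6  # us
    read_write_time = max(read_write_time / read_write_efficiency,
                          kernel_launch_overhead)
    if flops == 0:
        return read_write_time

    # Compute side: FLOPs at a fraction of peak throughput.
    compute_time = flops / (peak_tflops * 10**12) * 1e6  # us
    compute_time = max(compute_time / compute_efficiency,
                       kernel_launch_overhead)

    # The kernel is bound by whichever side is slower.
    return max(compute_time, read_write_time)
```

Note that applying `max(..., kernel_launch_overhead)` to each side separately (as the commit does) means even a trivially small op is never estimated below the launch overhead.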
