1 file changed: +15 -6 lines changed
Original file line number | Diff line number | Diff line change
@@ -300,18 +300,27 @@ def estimate_strategy_runtime_cost(node, strategy):
300300 gpu_memory_bandwidth = _get_device_gmem_bandwidth ()
301301 read_write_time = read_write_bytes / gpu_memory_bandwidth * 1e6 # us
302302
303+ # assume the operator achieves 70% of the peak memory bandwidth
304+ read_write_efficiency = 0.70
305+
306+ kernel_launch_overhead = 7 # us
307+
308+ read_write_time = max (
309+ read_write_time / read_write_efficiency , kernel_launch_overhead
310+ )
311+
312+ if flops == 0 :
313+ return read_write_time
303314 # TODO: fix this
304315 dtype = strategy .input_specs [0 ].tensor_meta .dtype
305316
306- # TODO: better handle this case
307- if dtype .is_complex :
308- return read_write_time
309317 # TODO: use PyTorch's version once it's giving correct results
310318 gpu_flops = _get_device_tflops (dtype ) * 10 ** 12
311319
312- # suppose 50% efficiency for the operator
313- factor = 1 / 0.5
314- compute_time = factor * flops / gpu_flops * 1e6 # us
320+ # assume the operator achieves 70% of the peak compute throughput (FLOPS)
321+ compute_efficiency = 0.70
322+ compute_time = flops / gpu_flops * 1e6 # us
323+ compute_time = max (compute_time / compute_efficiency , kernel_launch_overhead )
315324
316325 return max (compute_time , read_write_time )
317326
You can’t perform that action at this time.
0 commit comments