Commit d0d99d2

Authored by juju812, He Jun, and yzh119
Use global TuningConfig to fix memory leak caused by AutoTuner LRU cache and dynamic lambda TuningConfig (#2140)
## 📌 Description

This PR fixes a memory leak caused by the AutoTuner LRU cache together with `TuningConfig` objects that were rebuilt on every call with fresh dynamic lambdas.

## 🔍 Related Issues

#2139

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Performance**
  * Reduced autotuner overhead by caching runner parameter names to avoid repeated signature inspection during profiling, speeding up tuning runs.
* **New Features**
  * Centralized reusable tuning presets for mixed-precision GEMM (FP8/FP4) to improve autotuning and execution efficiency.

Co-authored-by: He Jun <hejun01@netease.com>
Co-authored-by: yzh119 <zihaoy@nvidia.com>
1 parent 1940b28 commit d0d99d2

File tree: 2 files changed, +82 −49 lines
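
Before the diffs, a minimal sketch of the failure mode described above. Everything here is hypothetical stand-in code (`TuningConfigSketch`, `profile`, and both call sites are invented for illustration, not FlashInfer's actual `AutoTuner` or `TuningConfig`): it shows why a config rebuilt with fresh lambdas on every call defeats a cache keyed on it, while a module-level config gives one stable key.

```python
from functools import lru_cache


class TuningConfigSketch:
    """Stand-in for a tuning config that captures a lambda."""

    def __init__(self, constraint_fn):
        self.constraint_fn = constraint_fn


@lru_cache(maxsize=None)  # unbounded: every distinct key is retained forever
def profile(config):
    # Pretend this is an expensive profiling run keyed on the config object.
    return config.constraint_fn(((16, 4096),))


def gemm_with_per_call_config():
    # A new config (and a new lambda) on every call -> a brand-new cache key,
    # so the cache misses every time and grows without bound.
    return profile(TuningConfigSketch(lambda shapes: shapes[0][-2]))


# Built once at module scope -> one stable cache key for the process lifetime.
_GLOBAL_CONFIG = TuningConfigSketch(lambda shapes: shapes[0][-2])


def gemm_with_global_config():
    return profile(_GLOBAL_CONFIG)


if __name__ == "__main__":
    for _ in range(1000):
        gemm_with_per_call_config()
        gemm_with_global_config()
    # Roughly 1001 misses and 1001 retained entries from the per-call pattern
    # (plus the single global entry), versus 999 hits for the global pattern.
    print(profile.cache_info())
```

Per-call configs hash by object identity, so they never produce a repeat key; those entries pile up while the single global config is reused.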

flashinfer/autotuner.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -458,6 +458,13 @@ def choose_one(
         # Record the total configs to try
         self.stats.tuned_op_total_configs[custom_op] = len(profiles)
 
+        # Pre-compute runner arg names to avoid calling inspect.signature in the loop
+        runner_arg_names_map = {}
+        for r in runners:
+            runner_arg_names_map[r] = {
+                param.name for param in inspect.signature(r.forward).parameters.values()
+            }
+
         for p in profiles:
             tensors = self._prepare_input_tensors(p, inputs)
             is_cache_hit, runner_id, tactic, _ = self.search_cache(
@@ -470,9 +477,7 @@ def choose_one(
             for r_id, r in enumerate(runners):
                 # TODO: use FakeTensor here.
                 valid_tactics = r.get_valid_tactics(tensors, p)
-                runner_arg_names = {
-                    p.name for p in inspect.signature(r.forward).parameters.values()
-                }
+                runner_arg_names = runner_arg_names_map[r]
                 if "do_preparation" in runner_arg_names and len(valid_tactics) > 0:
                     r(tensors, tactic=-1, do_preparation=True, **kwargs)
                 for tac in valid_tactics:
```
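
A standalone sketch of the same hoist, with made-up runner functions and loop bounds (not the real autotuner internals): each runner's parameter names are resolved with `inspect.signature` once up front, and the profiling loops only do dictionary lookups.

```python
import inspect


def forward_a(tensors, tactic, do_preparation=False):
    return tactic


def forward_b(tensors, tactic):
    return tactic


runners = [forward_a, forward_b]

# Computed once, outside the hot loop; mirrors runner_arg_names_map above.
runner_arg_names_map = {
    r: {p.name for p in inspect.signature(r).parameters.values()} for r in runners
}

for _profile in range(100):  # stands in for `for p in profiles:`
    for r in runners:
        runner_arg_names = runner_arg_names_map[r]  # dict lookup, no re-inspection
        if "do_preparation" in runner_arg_names:
            r(None, tactic=-1, do_preparation=True)
```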

flashinfer/gemm/gemm_base.py

Lines changed: 74 additions & 46 deletions

```diff
@@ -356,6 +356,25 @@ def forward(
         )
 
 
+_FP8_GEMM_SM100_TUNING_CONFIG = TuningConfig(
+    dynamic_tensor_specs=(
+        DynamicTensorSpec(
+            (0,),  # a_tensor_index
+            (-2,),
+            get_last_power_of_2_num_tokens_buckets,
+            last_positive_power_of_2,
+        ),
+    ),
+    constraint_specs=(
+        ConstraintSpec(
+            4,  # out_tensor_index
+            -2,
+            lambda shapes: shapes[0][-2],
+        ),
+    ),
+)
+
+
 def fp8_gemm_sm100(
     a: torch.Tensor,
     b: torch.Tensor,
@@ -376,29 +395,12 @@ def fp8_gemm_sm100(
         runners.append(_cudnn_gemm_fp8_runner())
     assert runners, "No suitable runners found"
    tuner = AutoTuner.get()
-    a_tensor_index = 0
-    out_tensor_index = 4
-    tuning_config = TuningConfig(
-        dynamic_tensor_specs=(
-            DynamicTensorSpec(
-                (a_tensor_index,),
-                (-2,),
-                get_last_power_of_2_num_tokens_buckets,
-                last_positive_power_of_2,
-            ),
-        ),
-        constraint_specs=(
-            ConstraintSpec(
-                out_tensor_index, -2, lambda shapes: shapes[a_tensor_index][-2]
-            ),
-        ),
-    )
 
     inputs = [a, b, scale_a, scale_b, out, workspace_buffer]
     runner, tactic = tuner.choose_one(
         "fp8_gemm",
         runners,
-        tuning_config,
+        _FP8_GEMM_SM100_TUNING_CONFIG,
         inputs,
     )
 
```
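
One side effect of moving the config to module scope is that the former locals `a_tensor_index` and `out_tensor_index` are no longer visible to the lambdas, so the indices are written as literals with comments (`(0,)  # a_tensor_index`, `4  # out_tensor_index`) and the constraint becomes `lambda shapes: shapes[0][-2]`. The property the refactor relies on is simply that a module-level object, including the lambdas inside it, is constructed once at import time; a hypothetical check in plain Python:

```python
# Hypothetical check, not FlashInfer code: a module-level object is built once,
# so its identity (and default hash) is stable across calls, while an object
# rebuilt inside a function is brand new on every call.
_GLOBAL_SPECS = (lambda shapes: shapes[0][-2],)


def per_call_specs():
    return (lambda shapes: shapes[0][-2],)


assert per_call_specs() is not per_call_specs()  # fresh tuple and lambda each call
assert _GLOBAL_SPECS is _GLOBAL_SPECS            # one object for the process lifetime
```

The mm_fp4 path below gets the same treatment: its padding helper and both layout-specific configs move to module scope.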

```diff
@@ -2019,6 +2021,58 @@ def _heuristic_func_mm_fp4(
     return [c for c in candidate_backends if c in suitable_backends]
 
 
+def _pad_up(x, y):
+    return ((x + y - 1) // y) * y
+
+
+_MM_FP4_TUNING_CONFIG_8x4 = TuningConfig(
+    dynamic_tensor_specs=(
+        DynamicTensorSpec(
+            (0,),  # a_tensor_index
+            (0,),
+            get_last_power_of_2_num_tokens_buckets,
+            last_positive_power_of_2,
+        ),
+    ),
+    constraint_specs=(
+        ConstraintSpec(
+            2,  # a_scale_tensor_index
+            0,
+            lambda shapes: _pad_up(shapes[0][0], 8),
+        ),
+        ConstraintSpec(
+            6,  # out_tensor_index
+            0,
+            lambda shapes: shapes[0][0],
+        ),
+    ),
+)
+
+
+_MM_FP4_TUNING_CONFIG_128x4 = TuningConfig(
+    dynamic_tensor_specs=(
+        DynamicTensorSpec(
+            (0,),  # a_tensor_index
+            (0,),
+            get_last_power_of_2_num_tokens_buckets,
+            last_positive_power_of_2,
+        ),
+    ),
+    constraint_specs=(
+        ConstraintSpec(
+            2,  # a_scale_tensor_index
+            0,
+            lambda shapes: _pad_up(shapes[0][0], 128),
+        ),
+        ConstraintSpec(
+            6,  # out_tensor_index
+            0,
+            lambda shapes: shapes[0][0],
+        ),
+    ),
+)
+
+
 @backend_requirement(
     {
         "cudnn": _cudnn_gemm_fp4_requirement,
@@ -2138,34 +2192,8 @@ def mm_fp4(
     # Now we have a list of runners for desired & supported backends.
     tuner = AutoTuner.get()
 
-    a_tensor_index = 0
-    a_scale_tensor_index = 2
-    out_tensor_index = 6
-
-    def pad_up(x, y):
-        return ((x + y - 1) // y) * y
-
-    tuning_config = TuningConfig(
-        dynamic_tensor_specs=(
-            DynamicTensorSpec(
-                (a_tensor_index,),
-                (0,),
-                get_last_power_of_2_num_tokens_buckets,
-                last_positive_power_of_2,
-            ),
-        ),
-        constraint_specs=(
-            ConstraintSpec(
-                a_scale_tensor_index,
-                0,
-                lambda shapes: pad_up(
-                    shapes[a_tensor_index][0], 8 if use_8x4_sf_layout else 128
-                ),
-            ),
-            ConstraintSpec(
-                out_tensor_index, 0, lambda shapes: shapes[a_tensor_index][0]
-            ),
-        ),
+    tuning_config = (
+        _MM_FP4_TUNING_CONFIG_8x4 if use_8x4_sf_layout else _MM_FP4_TUNING_CONFIG_128x4
     )
 
     inputs = [
```
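
The only arithmetic in the new module-level code is `_pad_up`, which rounds its first argument up to the next multiple of the second. In the constraint specs it pads the scale tensor's leading dimension to a multiple of 8 for the 8x4 scale-factor layout or 128 for the 128x4 layout. A few worked values (helper copied from the diff, assertions added here for illustration):

```python
def _pad_up(x, y):
    # Round x up to the nearest multiple of y.
    return ((x + y - 1) // y) * y


assert _pad_up(100, 8) == 104    # 8x4 scale-factor layout
assert _pad_up(100, 128) == 128  # 128x4 scale-factor layout
assert _pad_up(128, 128) == 128  # already aligned, unchanged
```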
