1. Problem Description
When frameworks such as vLLM invoke FlashInfer's `mm_fp4` (with `backend="cutlass"`) or `fp8_gemm_sm100` interfaces, memory usage increases continuously, potentially leading to OOM (Out Of Memory) errors.
The screenshot below is based on version 0.3.1.
2. Root Cause Analysis
The root cause of the memory leak lies in the incompatibility between the AutoTuner caching mechanism and the dynamic creation of TuningConfig objects.
2.1 Call Chain
- AutoTuner caching mechanism:
  The `flashinfer.autotuner.AutoTuner` class has a method `_find_nearest_profile` decorated with `@lru_cache(maxsize=None)`:

  ```python
  @classmethod
  @lru_cache(maxsize=None)
  def _find_nearest_profile(
      cls, shapes: Tuple[torch.Size], tuning_config: TuningConfig
  ) -> Tuple:
      ...
  ```

  This method uses `(shapes, tuning_config)` as the cache key.

- Dynamic creation of `TuningConfig`:
  In the current code, `TuningConfig` objects are created anew inside the `mm_fp4` and `fp8_gemm_sm100` functions every time they are called:

  ```python
  # Example code
  def mm_fp4(...):
      ...
      tuning_config = TuningConfig(
          dynamic_tensor_specs=(
              DynamicTensorSpec(...),
          ),
          constraint_specs=(
              ConstraintSpec(..., lambda shapes: ...),  # <--- New lambda created each time
          ),
      )
      runner, tactic = tuner.choose_one(
          "fp4_gemm",
          runners,
          tuning_config,
          inputs,
      )
  ```

- Inconsistent hash values:
  `TuningConfig` contains `ConstraintSpec`, and `ConstraintSpec` contains a `lambda` function. In Python, evaluating a `lambda` expression creates a new function object each time, even when the code is identical; these objects have different identities and therefore different hash values. Consequently, the `TuningConfig` object generated in each call to `mm_fp4` compares unequal to every previous one.

- Unbounded cache growth:
  Since the `tuning_config` argument passed to `_find_nearest_profile` differs on every call, `lru_cache` never hits an existing entry and instead adds a new one. With `maxsize=None`, the cache grows without bound, producing the memory leak.
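The hashing behavior described above can be demonstrated in isolation with a few lines of standalone Python (a minimal sketch, independent of FlashInfer):

```python
import functools

@functools.lru_cache(maxsize=None)
def lookup(key):
    # Stand-in for any lru_cache-decorated function keyed on its arguments.
    return key

# Each evaluation of a lambda expression yields a distinct function object,
# so the cache key never repeats and every call is a miss.
for _ in range(5):
    lookup((lambda x: x,))
print(lookup.cache_info().currsize)  # 5 -- one new entry per call

# Reusing a single lambda (hence a single key) hits the cache after the
# first call, so only one additional entry is created.
shared_key = (lambda x: x,)
for _ in range(5):
    lookup(shared_key)
print(lookup.cache_info().currsize)  # 6 -- the shared key added just one entry
```

This is exactly the pattern triggered by creating `ConstraintSpec(..., lambda shapes: ...)` inside `mm_fp4` on every call.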
3. Affected Interfaces
- `flashinfer.gemm.mm_fp4` (when `backend="cutlass"`)
- `flashinfer.gemm.fp8_gemm_sm100` (the FP8 GEMM implementation on SM100+ architectures)
4. Verification Method
A minimal reproducible example (`reproduce_leak.py`) is provided below to simulate this behavior.

`reproduce_leak.py`:
```python
import functools
from dataclasses import dataclass
from typing import Callable, Tuple

# --- Mocking FlashInfer Classes ---

@dataclass(unsafe_hash=True)
class ConstraintSpec:
    input_idx: int
    dim_idx: int
    infer_shape: Callable

@dataclass(unsafe_hash=True)
class TuningConfig:
    constraint_specs: Tuple[ConstraintSpec, ...] = ()

# --- Mocking AutoTuner ---

class AutoTuner:
    _instance = None

    @classmethod
    def get(cls):
        if cls._instance is None:
            cls._instance = AutoTuner()
        return cls._instance

    # This mimics the behavior of AutoTuner._find_nearest_profile.
    # The critical part is @functools.lru_cache(maxsize=None).
    @functools.lru_cache(maxsize=None)
    def _find_nearest_profile(self, shapes, tuning_config):
        # The logic inside doesn't matter for the leak;
        # what matters is that (shapes, tuning_config) form the cache key.
        return shapes

    def choose_one(self, tuning_config, inputs):
        # Simulate extracting shapes
        input_shapes = tuple(inputs)
        # Call the cached method
        self._find_nearest_profile(input_shapes, tuning_config)
        # Return current cache size for monitoring
        return self._find_nearest_profile.cache_info().currsize

# --- Simulation ---

def run_simulation(iterations=100, fix=False):
    tuner = AutoTuner.get()
    # Clear cache to start fresh
    tuner._find_nearest_profile.cache_clear()
    print(f"Running simulation with fix={fix} for {iterations} iterations...")

    # Fixed version: create the config ONCE outside the loop.
    # This simulates defining the config at module level or outside the function call.
    fixed_lambda = lambda x: x
    fixed_config = TuningConfig(
        constraint_specs=(
            ConstraintSpec(0, 0, fixed_lambda),
        )
    )

    for i in range(iterations):
        if fix:
            # FIX: Reuse the same config object
            config = fixed_config
        else:
            # LEAK: Create a NEW config object with a NEW lambda every time.
            # This simulates the bug in cutlass_fp4_gemm where TuningConfig
            # was created inside the function.
            config = TuningConfig(
                constraint_specs=(
                    ConstraintSpec(0, 0, lambda x: x),  # <--- New lambda function created here!
                )
            )
        # Simulate constant input shapes (e.g., batch size 1, seq len 128)
        inputs = ((1, 128), (128, 128))
        cache_size = tuner.choose_one(config, inputs)
        if i % 20 == 0:
            print(f"  Iteration {i:3d}: Cache size = {cache_size}")

    final_size = tuner._find_nearest_profile.cache_info().currsize
    print(f"Final Cache size: {final_size}")
    if not fix and final_size == iterations:
        print("  -> LEAK CONFIRMED: Cache size equals number of iterations.")
    elif fix and final_size == 1:
        print("  -> FIX VERIFIED: Cache size remains constant (1).")
    else:
        print("  -> Unexpected result.")

if __name__ == "__main__":
    print("=== Scenario 1: Reproducing the Memory Leak ===")
    run_simulation(iterations=100, fix=False)
    print("\n" + "=" * 40 + "\n")
    print("=== Scenario 2: Verifying the Fix ===")
    run_simulation(iterations=100, fix=True)
```

How to run:

```
python reproduce_leak.py
```

Output:
```
=== Scenario 1: Reproducing the Memory Leak ===
Running simulation with fix=False for 100 iterations...
  Iteration   0: Cache size = 1
  Iteration  20: Cache size = 21
  Iteration  40: Cache size = 41
  Iteration  60: Cache size = 61
  Iteration  80: Cache size = 81
Final Cache size: 100
  -> LEAK CONFIRMED: Cache size equals number of iterations.

========================================

=== Scenario 2: Verifying the Fix ===
Running simulation with fix=True for 100 iterations...
  Iteration   0: Cache size = 1
  Iteration  20: Cache size = 1
  Iteration  40: Cache size = 1
  Iteration  60: Cache size = 1
  Iteration  80: Cache size = 1
Final Cache size: 1
  -> FIX VERIFIED: Cache size remains constant (1).
```
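Based on the analysis above, one natural fix direction is to create the `TuningConfig` (and, crucially, the lambda inside it) once at module scope, so that every call passes the same hashable object to the cache. The sketch below reuses the mock classes from the repro script and hypothetical names (`_MM_FP4_CONFIG`, `mm_fp4_like`); it is a sketch of the idea, not the actual FlashInfer patch:

```python
import functools
from dataclasses import dataclass
from typing import Callable, Tuple

@dataclass(unsafe_hash=True)
class ConstraintSpec:
    input_idx: int
    dim_idx: int
    infer_shape: Callable

@dataclass(unsafe_hash=True)
class TuningConfig:
    constraint_specs: Tuple[ConstraintSpec, ...] = ()

# Created once at import time: the lambda (and hence the config's hash)
# is stable across calls. (_MM_FP4_CONFIG is a hypothetical name.)
_MM_FP4_CONFIG = TuningConfig(
    constraint_specs=(ConstraintSpec(0, 0, lambda shapes: shapes[0]),),
)

@functools.lru_cache(maxsize=None)
def _find_nearest_profile(shapes, tuning_config):
    return shapes

def mm_fp4_like(shapes):
    # Every call reuses the module-level config, so repeated shapes hit the cache.
    return _find_nearest_profile(shapes, _MM_FP4_CONFIG)

for _ in range(100):
    mm_fp4_like(((1, 128), (128, 128)))
print(_find_nearest_profile.cache_info().currsize)  # 1 -- no growth
```

A complementary mitigation is to give the cache a bound (e.g. `@lru_cache(maxsize=1024)` instead of `maxsize=None`), so that even pathological keys cannot grow memory without limit.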
