
NVFP4 CPU RAM leak caused by dynamic creation of TuningConfig objects #2139

@juju812

Description

1. Problem Description

When frameworks such as vLLM invoke FlashInfer's mm_fp4 (backend="cutlass") or fp8_gemm_sm100 interfaces, host (CPU) memory usage grows continuously, potentially leading to OOM (Out Of Memory) errors.

[Screenshot: continuous RAM growth, captured on version 0.3.1]

2. Root Cause Analysis

The root cause of the memory leak lies in the incompatibility between the AutoTuner caching mechanism and the dynamic creation of TuningConfig objects.

2.1 Call Chain

  1. AutoTuner Caching Mechanism:
    The flashinfer.autotuner.AutoTuner class has a method _find_nearest_profile decorated with @lru_cache(maxsize=None):

    @classmethod
    @lru_cache(maxsize=None)
    def _find_nearest_profile(
        cls, shapes: Tuple[torch.Size], tuning_config: TuningConfig
    ) -> Tuple:

    This method uses (shapes, tuning_config) as the cache key.

  2. Dynamic Creation of TuningConfig:
    In the current code, TuningConfig objects are created anew inside the mm_fp4 and fp8_gemm_sm100 functions every time they are called:

    # Example code
    def mm_fp4(...):
        # ...
        tuning_config = TuningConfig(
            dynamic_tensor_specs=(
                DynamicTensorSpec(...),
            ),
            constraint_specs=(
                ConstraintSpec(..., lambda shapes: ...), # <--- New lambda created each time
            ),
        )
        runner, tactic = tuner.choose_one(
            "fp4_gemm",
            runners,
            tuning_config,
            inputs,
        )
  3. Inconsistent Hash Values:

    • TuningConfig contains ConstraintSpec.
    • ConstraintSpec contains a lambda function.
    • In Python, executing a lambda expression creates a new function object each time, even if the code logic is identical. These function objects have different memory addresses (and thus different hash values).
    • Consequently, the TuningConfig object generated by each call to mm_fp4 compares unequal to, and hashes differently from, every previously created one.
  4. Unbounded Cache Growth:
    Since the tuning_config argument passed to _find_nearest_profile is different every time, lru_cache fails to hit existing entries and adds new entries to the cache instead. With maxsize=None, the cache grows indefinitely, causing a memory leak.
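The lambda-identity behavior described in point 3 can be demonstrated in isolation. This is a standalone sketch (not FlashInfer code): a cache keyed on a function object, mirroring how _find_nearest_profile is keyed on a TuningConfig that holds a lambda.

```python
import functools

@functools.lru_cache(maxsize=None)
def cached_lookup(fn):
    # The cache key is the function object itself.
    return id(fn)

def make_lambda():
    # Each call re-executes the lambda expression, producing a
    # brand-new function object with a new id (and thus a new hash).
    return lambda shapes: shapes

a, b = make_lambda(), make_lambda()
print(a == b)  # False: identical source code, distinct function objects

for _ in range(5):
    cached_lookup(make_lambda())  # every call is a cache miss

print(cached_lookup.cache_info().currsize)  # 5: one entry per lambda
```

With maxsize=None, nothing ever evicts these entries, which is exactly the unbounded growth described above.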

3. Affected Interfaces

  • flashinfer.gemm.mm_fp4 (when backend="cutlass")
  • flashinfer.gemm.fp8_gemm_sm100 (FP8 GEMM implementation on SM100+ architectures)

4. Verification Method

A minimal reproducible example (reproduce_leak.py) is provided to simulate this behavior.

`reproduce_leak.py`
import functools
from dataclasses import dataclass
from typing import Callable, Tuple

# --- Mocking FlashInfer Classes ---

@dataclass(unsafe_hash=True)
class ConstraintSpec:
    input_idx: int
    dim_idx: int
    infer_shape: Callable

@dataclass(unsafe_hash=True)
class TuningConfig:
    constraint_specs: Tuple[ConstraintSpec, ...] = ()

# --- Mocking AutoTuner ---

class AutoTuner:
    _instance = None
    
    @classmethod
    def get(cls):
        if cls._instance is None:
            cls._instance = AutoTuner()
        return cls._instance

    # This mimics the behavior of AutoTuner._find_nearest_profile
    # The critical part is @functools.lru_cache(maxsize=None)
    @functools.lru_cache(maxsize=None)
    def _find_nearest_profile(self, shapes, tuning_config):
        # The logic inside doesn't matter for the leak, 
        # what matters is that (shapes, tuning_config) are used as the cache key.
        return shapes

    def choose_one(self, tuning_config, inputs):
        # Simulate extracting shapes
        input_shapes = tuple(inputs)
        
        # Call the cached method
        self._find_nearest_profile(input_shapes, tuning_config)
        
        # Return current cache size for monitoring
        return self._find_nearest_profile.cache_info().currsize

# --- Simulation ---

def run_simulation(iterations=100, fix=False):
    tuner = AutoTuner.get()
    # Clear cache to start fresh
    tuner._find_nearest_profile.cache_clear()
    
    print(f"Running simulation with fix={fix} for {iterations} iterations...")
    
    # Fixed version: Create config ONCE outside the loop
    # This simulates defining the config at module level or outside the function call
    fixed_lambda = lambda x: x
    fixed_config = TuningConfig(
        constraint_specs=(
            ConstraintSpec(0, 0, fixed_lambda),
        )
    )

    for i in range(iterations):
        if fix:
            # FIX: Reuse the same config object
            config = fixed_config
        else:
            # LEAK: Create a NEW config object with a NEW lambda every time
            # This simulates the bug in cutlass_fp4_gemm where TuningConfig was created inside the function
            config = TuningConfig(
                constraint_specs=(
                    ConstraintSpec(0, 0, lambda x: x), # <--- New lambda function created here!
                )
            )
        
        # Simulate constant input shapes (e.g., batch size 1, seq len 128)
        inputs = ((1, 128), (128, 128))
        
        cache_size = tuner.choose_one(config, inputs)
        
        if i % 20 == 0:
            print(f"  Iteration {i:3d}: Cache size = {cache_size}")

    final_size = tuner._find_nearest_profile.cache_info().currsize
    print(f"Final Cache size: {final_size}")
    if not fix and final_size == iterations:
        print("  -> LEAK CONFIRMED: Cache size equals number of iterations.")
    elif fix and final_size == 1:
        print("  -> FIX VERIFIED: Cache size remains constant (1).")
    else:
        print("  -> Unexpected result.")

if __name__ == "__main__":
    print("=== Scenario 1: Reproducing the Memory Leak ===")
    run_simulation(iterations=100, fix=False)
    
    print("\n" + "="*40 + "\n")
    
    print("=== Scenario 2: Verifying the Fix ===")
    run_simulation(iterations=100, fix=True)

How to run:

python reproduce_leak.py

Output:

=== Scenario 1: Reproducing the Memory Leak ===
Running simulation with fix=False for 100 iterations...
  Iteration   0: Cache size = 1
  Iteration  20: Cache size = 21
  Iteration  40: Cache size = 41
  Iteration  60: Cache size = 61
  Iteration  80: Cache size = 81
Final Cache size: 100
  -> LEAK CONFIRMED: Cache size equals number of iterations.

========================================

=== Scenario 2: Verifying the Fix ===
Running simulation with fix=True for 100 iterations...
  Iteration   0: Cache size = 1
  Iteration  20: Cache size = 1
  Iteration  40: Cache size = 1
  Iteration  60: Cache size = 1
  Iteration  80: Cache size = 1
Final Cache size: 1
  -> FIX VERIFIED: Cache size remains constant (1).
