51 changes: 51 additions & 0 deletions vllm/lora/ops/triton_ops/README_TUNING.md
@@ -0,0 +1,51 @@
# Multi-LoRA Tuning
> **Collaborator:** Can we add a description of how to do tuning, so that users can perform tuning by reading this document? We can refer to https://github.com/sgl-project/sglang/blob/main/benchmark/kernels/fused_moe_triton/README.md#tuning-tool

> **Collaborator:** Sorry, I didn't submit my comment. After responding to the above comment, we can consider merging this PR.


**Note**: The LoRA configuration folder should be specified by exporting `VLLM_TUNED_CONFIG_FOLDER=/path/to/configs`. Without this, the shrink/expand kernels will use default configurations.

## Tuning Process

Multi-LoRA shrink/expand Triton kernel tuning follows a methodology similar to [Triton MoE tuning](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py).

**Step 1**
Define the search space. An example search space:

```python
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256]
block_k_range = [32, 64, 128, 256]
num_warps_range = [4, 8]
num_stage_range = [2, 3, 4, 5]
num_ctas_range = [1]
split_k_range = [4, 8, 16, 32, 64]
```
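
For illustration, the full grid can be enumerated with `itertools.product`. The sketch below is not part of vLLM; it reuses the range lists defined above and simply builds and counts the candidate configurations (`split_k` only applies to the shrink kernel):

```python
import itertools

# Candidate configurations from the ranges above; split_k is shrink-only.
candidate_configs = [
    {
        "block_m": bm, "block_n": bn, "block_k": bk,
        "num_warps": w, "num_stages": s, "num_ctas": c, "split_k": sk,
    }
    for bm, bn, bk, w, s, c, sk in itertools.product(
        block_m_range, block_n_range, block_k_range,
        num_warps_range, num_stage_range, num_ctas_range, split_k_range,
    )
]
print(f"{len(candidate_configs)} candidate configurations")  # 5 * 4 * 4 * 2 * 4 * 1 * 5 = 3200
```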

**Step 2**
Get all `hidden_state` sizes and `num_slices` values that the target model uses for a specific TP size.

For example, we can acquire this information by adding a few prints in [add_lora_linear](https://github.com/li2haipeng/vllm/blob/multi_lora_v01011/vllm/lora/punica_wrapper/punica_gpu.py#L192):

```python
print(f"x_shape: {x.view(-1, x.shape[-1]).shape}")
print(f"num_sclises: {len(output_slices)}")
for i in range(len(output_slices)):
print(f"a{i} shape: {lora_a_stacked[i].shape}")
print(f"b{i} shape: {lora_b_stacked[i].shape}")
print("y_shape", y.shape)
```
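
These prints can also be turned into a small collection step. The helper below is a hypothetical sketch (the name `record_shape` is not part of vLLM) that deduplicates the observed problem shapes so only unique shapes need to be benchmarked:

```python
import torch

# Hypothetical helper: collect unique (num_tokens, hidden_size, num_slices)
# problem shapes observed while running the target model.
tuning_shapes: set[tuple[int, int, int]] = set()

def record_shape(x: torch.Tensor, output_slices: tuple[int, ...]) -> None:
    num_tokens, hidden_size = x.view(-1, x.shape[-1]).shape
    tuning_shapes.add((num_tokens, hidden_size, len(output_slices)))
```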

**Step 3**
Benchmark the shrink/expand kernel runtime with each kernel configuration generated from the pre-defined search space, performing a grid search to find the optimal configuration for every shape collected in Step 2. vLLM's [benchmark_lora.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_lora.py) can be used to search for configurations for different shapes.
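
Conceptually, the grid search boils down to timing every candidate configuration and keeping the fastest one. The sketch below is illustrative only: `run_kernel(cfg)` is a placeholder for an actual shrink/expand launch on one of the shapes collected in Step 2, whereas `benchmark_lora.py` already takes care of shapes, dtypes, and result aggregation.

```python
from triton.testing import do_bench

def grid_search(candidate_configs, run_kernel):
    # run_kernel(cfg) is a placeholder callable that launches the shrink or
    # expand kernel with the given configuration on a fixed problem shape.
    best_ms, best_cfg = float("inf"), None
    for cfg in candidate_configs:
        try:
            ms = do_bench(lambda: run_kernel(cfg))
        except Exception:
            continue  # some grid points are invalid (e.g. exceed shared memory)
        if ms < best_ms:
            best_ms, best_cfg = ms, cfg
    return best_cfg, best_ms
```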

## Config Files

### File Name

For `shrink`, the config file is named `{gpu_name}_SHRINK.json`, e.g. `NVIDIA_H200_SHRINK.json`.

For `expand`, the config file is named `{gpu_name}_EXPAND_{add_inputs}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`.

The `gpu_name` is detected automatically via `torch.cuda.get_device_name()`, with spaces and dashes replaced by underscores.
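
This mirrors the normalization done in `load_lora_op_config` (see `vllm/lora/ops/triton_ops/utils.py` in this PR); for example:

```python
import torch

gpu_name = torch.cuda.get_device_name().replace(" ", "_").replace("-", "_")
add_inputs = True  # expand only

shrink_config = f"{gpu_name}_SHRINK.json"                            # e.g. NVIDIA_H200_SHRINK.json
expand_config = f"{gpu_name}_EXPAND_{str(add_inputs).upper()}.json"  # e.g. NVIDIA_H200_EXPAND_TRUE.json
```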

### Json Structure

Optimal kernel configurations are saved as JSON files with the nested structure `config_data[max_loras][num_slices][m][k][n]`. At lookup time, `num_slices` must match exactly, while the nearest available key is used for `max_loras`, `m`, `k`, and `n` when there is no exact match.
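
As an illustration of that nesting (the values below are placeholders, not tuned results), a shrink config file could be written like this:

```python
import json

# config_data[max_loras][num_slices][m][k][n] -> kernel parameters; all keys are strings.
config_data = {
    "1": {                     # max_loras
        "2": {                 # num_slices
            "256": {           # m (token count)
                "4096": {      # k
                    "16": {    # n
                        "block_m": 32,
                        "block_n": 16,
                        "block_k": 256,
                        "split_k": 64,
                        "num_warps": 4,
                        "num_ctas": 1,
                        "num_stages": 2,
                        "max_nreg": None,
                    }
                }
            }
        }
    }
}

with open("NVIDIA_H200_SHRINK.json", "w") as f:
    json.dump(config_data, f, indent=4)
```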
23 changes: 16 additions & 7 deletions vllm/lora/ops/triton_ops/lora_expand_op.py
@@ -10,7 +10,7 @@
import torch

from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op

@@ -201,12 +201,21 @@ def _lora_expand(
    NUM_SLICES = len(lora_b_weights)

    # Triton kernel configs.
    BLOCK_M = 64
    BLOCK_N = 128
    BLOCK_K = 16
    NUM_WARPS = 4
    NUM_CTAS = 1
    NUM_STAGES = 2
    kernel_config = get_lora_op_configs(
        op_type="expand",
        max_loras=MAX_LORAS,
        batch=M,
        hidden_size=MAX_N,
        rank=K,
        num_slices=NUM_SLICES,
        add_inputs=add_inputs,
    )
    BLOCK_M = kernel_config["block_m"]
    BLOCK_N = kernel_config["block_n"]
    BLOCK_K = kernel_config["block_k"]
    NUM_WARPS = kernel_config["num_warps"]
    NUM_CTAS = kernel_config["num_ctas"]
    NUM_STAGES = kernel_config["num_stages"]

    EVEN_K = K % BLOCK_K == 0  # type: ignore

25 changes: 16 additions & 9 deletions vllm/lora/ops/triton_ops/lora_shrink_op.py
@@ -10,7 +10,7 @@
import torch

from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton
from vllm.utils import direct_register_custom_op

@@ -177,14 +177,21 @@ def _lora_shrink(
    MAX_LORAS = lora_ids.size(0)

    # Triton kernel configs
    BLOCK_M = 32
    BLOCK_N = 16
    BLOCK_K = 256 if M < 128 else 32
    SPLIT_K = 64 if M < 128 else 8
    NUM_WARPS = 4
    NUM_CTAS = 1
    NUM_STAGES = 2

    kernel_config = get_lora_op_configs(
        "shrink",
        max_loras=MAX_LORAS,
        batch=M,
        hidden_size=K,
        rank=N,
        num_slices=NUM_SLICES,
    )
    BLOCK_M = kernel_config["block_m"]
    BLOCK_N = kernel_config["block_n"]
    BLOCK_K = kernel_config["block_k"]
    SPLIT_K = kernel_config["split_k"]
    NUM_WARPS = kernel_config["num_warps"]
    NUM_STAGES = kernel_config["num_stages"]
    NUM_CTAS = kernel_config["num_ctas"]
    EVEN_K = K % (BLOCK_K * SPLIT_K) == 0  # type: ignore

    # TODO (varun): This grid formulation maximizes parallelization at the
115 changes: 115 additions & 0 deletions vllm/lora/ops/triton_ops/utils.py
@@ -1,8 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
import json
from pathlib import Path
from typing import Any

import torch

from vllm import envs
from vllm.logger import init_logger

logger = init_logger(__name__)

_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}

@@ -133,3 +143,108 @@ def _get_lora_b_ptr(
        MAX_N,
    )
    return _LORA_B_PTR_DICT.get(key)


@functools.lru_cache
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
    user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
    if user_defined_config_folder is not None:
        gpu_name = torch.cuda.get_device_name()
        gpu_name = gpu_name.replace(" ", "_")
        gpu_name = gpu_name.replace("-", "_")

        config_fname = None
        if op_type == "shrink":
            config_fname = f"{gpu_name}_{op_type.upper()}.json"
        else:
            assert op_type == "expand"
            config_fname = (
                f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json"
            )

        config_path = Path(f"{user_defined_config_folder}/{config_fname}")
        if not config_path.exists():
            logger.warning_once(f"No LoRA kernel configs found in {config_path}")
            return None

        # Load json
        logger.info_once(f"Using tuned LoRA kernel configs from {config_path}.")
        with open(str(config_path)) as f:
            config_data = json.load(f)
    else:
        config_data = None

    return config_data


@functools.lru_cache
def get_lora_op_configs(
    op_type: str,
    max_loras: int,
    batch: int,
    hidden_size: int,
    rank: int,
    num_slices: int,
    add_inputs: bool | None = None,
) -> dict[str, int | None]:
    assert op_type in ["shrink", "expand"]
> **Contributor (critical):** The assertion `assert op_type in ["shrink", "expand"]` makes the logic for `op_type == "fused"` unreachable in the rest of the function (e.g., lines 217-218 and 224-239). This appears to be unintentional, as there is code to handle the fused case. If fused is meant to be supported, this assertion should be updated. Otherwise, the dead code should be removed to avoid confusion and potential bugs.
>
> Suggested change: `assert op_type in ["shrink", "expand"]` → `assert op_type in ["shrink", "expand", "fused"]`


    # default config
    default = {}
    if op_type == "shrink":
        default = {
            "block_m": 32,
            "block_n": 16,
            "block_k": 256 if batch < 128 else 32,
            "split_k": 64 if batch < 128 else 8,
            "num_warps": 4,
            "num_ctas": 1,
            "num_stages": 2,
            "max_nreg": None,
        }
    else:
        default = {
            "block_m": 64,
            "block_n": 128,
            "block_k": 16,
            "num_warps": 4,
            "num_ctas": 1,
            "num_stages": 2,
            "max_nreg": None,
        }
    m = batch

    k, n = (hidden_size, rank) if op_type == "shrink" else (rank, hidden_size)

    config_data: Any
    config_data = load_lora_op_config(op_type, add_inputs)
    if not config_data:
        logger.warning_once("Using default LoRA kernel configs")
        return default

    # config is structured as config_data[max_loras][num_slices][m][k][n] = {}
    # slice by max_loras
    config_data = (
        config_data.get(str(max_loras))
        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - max_loras))]
    )
    # slice by num_slices
    config_data = config_data[str(num_slices)]
    # slice by m
    config_data = (
        config_data.get(str(m))
        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - m))]
    )
    # slice by k
    config_data = (
        config_data.get(str(k))
        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - k))]
    )
    # slice by n
    config_data = (
        config_data.get(str(n))
        or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))]
    )

    assert config_data is not None
    return config_data