[Feature] Support dynamic loading and unloading of Lora adapters #2891

Draft · wants to merge 3 commits into base: main
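This draft wires a dynamic-loading path through the whole stack: the TokenizerManager accepts a LoadLoRAAdapterReqInput, forwards it to the Scheduler, which delegates to the TP worker's model runner and finally to LoRAManager.load_lora_adapter, where the adapter config and weights are loaded and the LoRA memory pool is grown for any newly added target modules. As a rough sketch of how the new entry point could be driven, assuming an already-initialized sglang TokenizerManager and placeholder adapter name/path (the HTTP-endpoint wiring is not part of this diff):

```python
from sglang.srt.managers.io_struct import LoadLoRAAdapterReqInput


async def add_adapter(tokenizer_manager) -> None:
    # `tokenizer_manager` is assumed to be a running sglang TokenizerManager;
    # the adapter name and path below are placeholders.
    req = LoadLoRAAdapterReqInput(
        lora_name="sql-adapter",           # must not collide with an already-loaded adapter
        lora_path="path/to/lora_adapter",  # local directory or Hugging Face repo id
    )
    success, message = await tokenizer_manager.load_lora_adapter(req)
    print(success, message)  # (True, "Success") on the happy path
```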
1 change: 1 addition & 0 deletions python/sglang/srt/lora/lora.py
@@ -324,6 +324,7 @@ def __init__(self, uid, config, base_hf_config, load_config):
self.weights = {}
self.weights_gpu = {}

@classmethod
def get_stacked_multiply(self, module_name):
stacked_rank = {
"qkv_proj": 3,
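The single added line above makes get_stacked_multiply a classmethod, so the memory-pool setup in lora_manager.py (below) can query a module's stacking factor via LoRAAdapter.get_stacked_multiply(...) without holding a loaded adapter instance, instead of reaching into self.loras[-1]. A minimal, self-contained sketch of the lookup it performs; only the "qkv_proj" entry is visible in the truncated hunk, and the default factor of 1 for unlisted modules is an assumption:

```python
# Minimal stand-in for LoRAAdapter.get_stacked_multiply. Only "qkv_proj" is visible
# in the hunk above; other entries and the default of 1 are assumptions.
STACKED_RANK = {
    "qkv_proj": 3,  # q, k and v projections are stacked into one fused module
}


def get_stacked_multiply(module_name: str) -> int:
    return STACKED_RANK.get(module_name, 1)


print(get_stacked_multiply("qkv_proj"))  # 3
```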
131 changes: 124 additions & 7 deletions python/sglang/srt/lora/lora_manager.py
@@ -131,16 +131,16 @@ def set_lora_module(self, module_name, module):
def init_loras(self):
# get configs and target modules
self.configs = {}
self.origin_target_modules = set()
origin_target_modules = set()
for name, path in self.lora_paths.items():
self.configs[name] = LoRAConfig(path)
self.origin_target_modules = set(self.origin_target_modules) | set(
origin_target_modules = set(origin_target_modules) | set(
self.configs[name].target_modules
)
if hasattr(self.base_model, "get_module_name"):
self.target_modules = {
self.base_model.get_module_name(module)
for module in self.origin_target_modules
for module in origin_target_modules
}
else:
logger.warning(
Expand All @@ -149,10 +149,10 @@ def init_loras(self):
"Use the default one, but please check if it is correct for your model."
)
self.target_modules = {
get_module_name(module) for module in self.origin_target_modules
get_module_name(module) for module in origin_target_modules
}
self.target_weights = set(
[get_stacked_name(module) for module in self.origin_target_modules]
[get_stacked_name(module) for module in origin_target_modules]
)

# load all weights to cpu
@@ -197,7 +197,7 @@ def init_lora_memory_pool(self):
"Use the default one, but please check if it is correct for your model."
)
hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
c = self.loras[-1].get_stacked_multiply(module_A)
c = LoRAAdapter.get_stacked_multiply(module_A)
if module_A not in self.A_buffer:
self.A_buffer[module_A] = [
torch.empty(
@@ -221,7 +221,7 @@
"Use the default one, but please check if it is correct for your model."
)
_, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
c = self.loras[-1].get_stacked_multiply(module_B)
c = LoRAAdapter.get_stacked_multiply(module_B)
if module_B not in self.B_buffer:
self.B_buffer[module_B] = [
torch.empty(
@@ -327,3 +327,120 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
seg_indptr,
weight_indices,
)

def load_lora_adapter(self, lora_name: str, lora_path: str):
if lora_name in self.lora_paths:
error_msg = f"There has been a Lora adapter named as {lora_name}! Please change for another name."
logger.error(error_msg)
return False, error_msg

# Load the LoRA config. Return False if loading it fails.
try:
self.configs[lora_name] = LoRAConfig(lora_path)
except Exception as e:
error_msg = (
f"Failed to load the LoRA config from path {lora_path}. "
f"Error log: {e}."
)
logger.error(error_msg)
return False, error_msg

# Update target modules.
added_target_modules = set(self.configs[lora_name].target_modules)
if hasattr(self.base_model, "get_module_name"):
self.target_modules = set(self.target_modules) | {
self.base_model.get_module_name(module)
for module in added_target_modules
}
else:
logger.warning(
"WARNING: get_module_name() is not defined, "
"which is used to map config module name to model implementation module name."
"Use the default one, but please check if it is correct for your model."
)
self.target_modules = set(self.target_modules) | {
get_module_name(module) for module in added_target_modules
}
self.target_weights = set(self.target_weights) | set(
[get_stacked_name(module) for module in added_target_modules]
)

# Load the adapter weights to CPU.
self.lora_id[lora_name] = len(self.loras)
lora_adapter = LoRAAdapter(lora_name, self.configs[lora_name], self.base_hf_config, self.load_config)
self.loras.append(lora_adapter)
lora_adapter.initialize_weights()

# Check that the new adapter matches the lora_dim and scaling of the running adapters.
# FIXME: Should be removed after supporting multi-rank loras.
if (self.max_lora_dim != self.configs[lora_name].r) or (self.scaling != lora_adapter.scaling):
error_msg = (
f"Currently sglang only supports serving of lora adapters with identical lora ranks"
f"and scaling. Newly loaded adapter has rank={self.configs[lora_name].r} and scaling={lora_adapter.scaling},"
f"but running adapters has rank={self.max_lora_dim} and scaling={self.scaling}. Please make sure they are equal!"
)
logger.error(error_msg)
return False, error_msg

# Replace base modules with LoRA modules for any newly matched target modules.
processed_module_names = set(lora_module[0] for lora_module in self.lora_modules)
for module_name, module in self.base_model.named_modules():
if self.match_target_modules(module_name) and module_name not in processed_module_names:
self.lora_modules.append(
(module_name, self.set_lora_module(module_name, module))
)

# Allocate new memory-pool buffers if target modules were newly added.
# TODO(Baizhou): This piece of logic can be reused in a helper function.
num_layer = self.base_hf_config.num_hidden_layers
for module_A, module_B in self.target_weights:
if module_A not in self.A_buffer:
if hasattr(self.base_model, "get_hidden_dim"):
hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
else:
logger.warning(
"WARNING: get_hidden_dim() is not defined, "
"which is used to get the hidden dim for different lora modules"
"Use the default one, but please check if it is correct for your model."
)
hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
c = LoRAAdapter.get_stacked_multiply(module_A)
if module_A not in self.A_buffer:
self.A_buffer[module_A] = [
torch.empty(
(
self.max_loras_per_batch,
self.max_lora_dim * c,
hidden_dim_A,
),
dtype=self.dtype,
device="cuda",
)
for i in range(num_layer)
]

if module_B not in self.B_buffer:
if hasattr(self.base_model, "get_hidden_dim"):
_, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
else:
logger.warning(
"WARNING: get_hidden_dim() is not defined, "
"which is used to get the hidden dim for different lora modules"
"Use the default one, but please check if it is correct for your model."
)
_, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
c = LoRAAdapter.get_stacked_multiply(module_B)
self.B_buffer[module_B] = [
torch.empty(
(
self.max_loras_per_batch,
hidden_dim_B * c,
self.max_lora_dim,
),
dtype=self.dtype,
device="cuda",
)
for i in range(num_layer)
]

return True, "Success"
26 changes: 26 additions & 0 deletions python/sglang/srt/managers/io_struct.py
@@ -468,6 +468,32 @@ class GetWeightsByNameReqOutput:
parameter: list


@dataclass
class LoadLoRAAdapterReqInput:
# The name under which the new LoRA adapter will be registered.
lora_name: str
# The local path or Hugging Face repo to load the adapter from.
lora_path: str
# The format used to load the adapter weights; defaults to the server's load
# format when left as None (referenced by TokenizerManager.load_lora_adapter).
load_format: Optional[str] = None


@dataclass
class LoadLoRAAdapterReqOutput:
success: bool
message: str


@dataclass
class UnLoadLoRAAdapterReqInput:
# The name of the LoRA adapter to unload.
lora_name: str


@dataclass
class UnLoadLoRAAdapterReqOutput:
success: bool
message: str


@dataclass
class AbortReq:
# The request id
27 changes: 27 additions & 0 deletions python/sglang/srt/managers/scheduler.py
@@ -43,6 +43,8 @@
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -516,6 +518,11 @@ def process_input_requests(self, recv_reqs: List):
elif isinstance(recv_req, GetWeightsByNameReqInput):
parameter = self.get_weights_by_name(recv_req)
self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
elif isinstance(recv_req, LoadLoRAAdapterReqInput):
success, message = self.load_lora_adapter(recv_req)
self.send_to_tokenizer.send_pyobj(
LoadLoRAAdapterReqOutput(success, message)
)
elif isinstance(recv_req, ProfileReq):
if recv_req == ProfileReq.START_PROFILE:
self.start_profile()
@@ -1543,6 +1550,26 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return parameter

def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
"""In-place loading a new lora adapater from disk or huggingface."""

if (not self.server_args.disable_radix_cache) or (
not self.server_args.disable_cuda_graph
):
success, message = (
False,
"Radix cache or cuda graph not supported when Lora is enabled, please try turning it off.",

[Review comment] Wonder why radix cache cannot work with LoRA?

[Author reply] Currently the radix cache is not compatible with LoRA. There is a plan to support this compatibility in the future, as tracked in #2929.

)
else:
success, message = self.tp_worker.load_lora_adapter(recv_req)

if success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after loading the LoRA adapter"
else:
logger.error(message)
return success, message

def start_profile(self) -> None:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
51 changes: 51 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -50,6 +50,8 @@
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
@@ -169,6 +171,8 @@ def __init__(
self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
None
)
self.lora_update_lock = RWLock()
self.lora_update_result: Optional[Awaitable[LoadLoRAAdapterReqOutput]] = None
self.asyncio_tasks = set()

# For session info
@@ -548,6 +552,44 @@ async def get_weights_by_name(
else:
return all_parameters

async def load_lora_adapter(
self,
obj: LoadLoRAAdapterReqInput,
request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
self.auto_create_handle_loop()

# Default the load format to the one configured in server_args.
if obj.load_format is None:
obj.load_format = self.server_args.load_format
logger.info("Start loading LoRA weights. Load format=%s", obj.load_format)

# Hold the writer lock so that LoRA loading cannot begin until
# in-flight requests have been processed.
async with self.lora_update_lock.writer_lock:
return await self._wait_for_lora_loading(obj)

async def _wait_for_lora_loading(
self, obj: LoadLoRAAdapterReqInput
) -> Tuple[bool, str]:
self.send_to_scheduler.send_pyobj(obj)
self.lora_update_result = asyncio.Future()
if self.server_args.dp_size == 1:
result = await self.lora_update_result
if result.success:
self.server_args.lora_paths[obj.lora_name] = obj.lora_path
return result.success, result.message
else: # self.server_args.dp_size > 1
self.lora_update_tmp_result = []
result = await self.lora_update_result
all_success = all([r.success for r in result])
if all_success:
self.server_args.lora_paths[obj.lora_name] = obj.lora_path
all_message = [r.message for r in result]
all_message = " | ".join(all_message)
return all_success, all_message

async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
@@ -627,6 +669,7 @@ async def handle_loop(self):
UpdateWeightsFromDistributedReqOutput,
GetWeightsByNameReqOutput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqOutput,
] = await self.recv_from_detokenizer.recv_pyobj()

if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
@@ -750,6 +793,14 @@ async def handle_loop(self):
self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
self.get_weights_by_name_communicator.handle_recv(recv_obj)
elif isinstance(recv_obj, LoadLoRAAdapterReqOutput):
if self.server_args.dp_size == 1:
self.lora_update_result.set_result(recv_obj)
else:
self.lora_update_tmp_result.append(recv_obj)
# Set the future once the results from all DP ranks have been received.
if len(self.lora_update_tmp_result) == self.server_args.dp_size:
self.lora_update_result.set_result(self.lora_update_tmp_result)
else:
raise ValueError(f"Invalid object: {recv_obj=}")

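The tokenizer-manager changes above follow the same request/response pattern used for weight updates: load_lora_adapter sends the request to the scheduler via send_to_scheduler.send_pyobj, parks on an asyncio.Future, and handle_loop resolves that future when the LoadLoRAAdapterReqOutput comes back (collecting dp_size replies first when data parallelism is enabled). A minimal, self-contained toy model of that pattern; the class and method names here are illustrative, not the actual sglang classes:

```python
import asyncio
from typing import Optional, Tuple


class MiniTokenizerManager:
    """Toy model of the future-based round trip: the caller awaits a Future that
    the receive loop resolves when the scheduler's reply arrives."""

    def __init__(self) -> None:
        self._pending: Optional[asyncio.Future] = None

    async def load_lora_adapter(self, name: str, path: str) -> Tuple[bool, str]:
        # The real code sends a LoadLoRAAdapterReqInput to the scheduler before awaiting.
        self._pending = asyncio.get_running_loop().create_future()
        return await self._pending

    def on_scheduler_reply(self, success: bool, message: str) -> None:
        # The real code does this inside handle_loop() for LoadLoRAAdapterReqOutput.
        self._pending.set_result((success, message))


async def main() -> None:
    manager = MiniTokenizerManager()
    task = asyncio.create_task(manager.load_lora_adapter("sql-adapter", "path/to/adapter"))
    await asyncio.sleep(0)                       # let the task start and park on the future
    manager.on_scheduler_reply(True, "Success")  # simulate the scheduler's reply
    print(await task)                            # (True, 'Success')


asyncio.run(main())
```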
7 changes: 7 additions & 0 deletions python/sglang/srt/managers/tp_worker.py
@@ -22,6 +22,7 @@
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
@@ -212,3 +213,9 @@ def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
recv_req.name, recv_req.truncate_size
)
return parameter

def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
success, message = self.model_runner.load_lora_adapter(
recv_req.lora_name, recv_req.lora_path
)
return success, message