From 424df6125e87c5788027661f5b00b841aa8f51ed Mon Sep 17 00:00:00 2001
From: SamuraiBUPT <31409163@bupt.edu.cn>
Date: Wed, 15 Nov 2023 20:23:42 +0800
Subject: [PATCH] blora LlaMa support #1

---
 examples/llama_test_lora.py                  |  48 +++++++
 vllm/engine/arg_utils.py                     |  29 +++-
 vllm/engine/async_llm_engine.py              |  17 ++-
 vllm/engine/llm_engine.py                    |  21 ++-
 vllm/entrypoints/llm.py                      |   4 +
 vllm/model_executor/lora_utils.py            | 134 +++++++++++++++++++
 vllm/model_executor/model_loader.py          |  22 ++-
 vllm/model_executor/models/llama.py          |  93 ++++++++++++-
 vllm/model_executor/parallel_utils/layers.py | 102 ++++++++++++++
 vllm/sampling_params.py                      |   2 +
 vllm/worker/worker.py                        |   5 +-
 11 files changed, 470 insertions(+), 7 deletions(-)
 create mode 100755 examples/llama_test_lora.py
 mode change 100644 => 100755 vllm/engine/arg_utils.py
 mode change 100644 => 100755 vllm/engine/llm_engine.py
 mode change 100644 => 100755 vllm/entrypoints/llm.py
 create mode 100755 vllm/model_executor/lora_utils.py
 mode change 100644 => 100755 vllm/model_executor/model_loader.py
 mode change 100644 => 100755 vllm/model_executor/models/llama.py
 mode change 100644 => 100755 vllm/model_executor/parallel_utils/layers.py
 mode change 100644 => 100755 vllm/sampling_params.py
 mode change 100644 => 100755 vllm/worker/worker.py

diff --git a/examples/llama_test_lora.py b/examples/llama_test_lora.py
new file mode 100755
index 0000000000000..96eb626523a08
--- /dev/null
+++ b/examples/llama_test_lora.py
@@ -0,0 +1,48 @@
+from vllm import LLM, SamplingParams
+import time
+if __name__ == "__main__":
+    prompt = "Hello and welcome, "
+    prompts = [prompt]
+    path = "./baichuan2-13b"
+    path = "/vllm_workspace/weights/llama_7b_hf"
+    lora_path = "./baichuan2-13b-20231013174626"
+    lora_path = "/vllm_workspace/weights/alpaca-lora-7b"
+    lora_path_2 = "./baichuan2-13b-20231013192059"
+    lora_path_2 = "/vllm_workspace/weights/bactrian-x-llama-7b-lora"
+    llm = LLM(model=path,
+              trust_remote_code=True,
+              lora_paths=[lora_path, lora_path_2],
+              adapter_names=["adapter_1", "adapter_2"])
+
+    print(llm.llm_engine.workers[0].model)
+
+    sampling_params = SamplingParams(temperature=0,
+                                     top_p=1,
+                                     best_of=2,
+                                     top_k=-1,
+                                     max_tokens=100,
+                                     use_beam_search=True,
+                                     lora_id="adapter_1")
+    llm._add_request(prompt=prompt,
+                     prompt_token_ids=None,
+                     sampling_params=sampling_params)
+
+    sampling_params = SamplingParams(temperature=0,
+                                     top_p=1,
+                                     best_of=2,
+                                     top_k=-1,
+                                     max_tokens=100,
+                                     use_beam_search=True,
+                                     lora_id="adapter_2")
+    llm._add_request(prompt=prompt,
+                     prompt_token_ids=None,
+                     sampling_params=sampling_params)
+    start = time.time()
+    outputs = llm._run_engine(use_tqdm=True)
+    end = time.time()
+    print(f"cost: {end - start}")
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
\ No newline at end of file
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
old mode 100644
new mode 100755
index cc425a2c079e7..7ba98b2f54aed
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1,7 +1,7 @@
 import argparse
 import dataclasses
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, List
 
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -32,6 +32,11 @@ class EngineArgs:
     revision: Optional[str] = None
     tokenizer_revision: Optional[str] = None
     quantization: Optional[str] = None
+
+    # MODIFY
+    lora_paths: Optional[List[str]] = None
+    adapter_names: Optional[List[str]] = None
+    # END
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -171,6 +176,28 @@ def add_cli_args(
                             choices=['awq', 'squeezellm', None],
                             default=None,
                             help='Method used to quantize the weights')
+
+        # MODIFY
+        parser.add_argument(
+            '--lora-paths',
+            metavar='path',
+            type=str,
+            default=None,
+            nargs='+',
+            help='the paths of lora model you want to load:' +
+            '[lora_path1 lora_path2 ...]')
+
+        parser.add_argument(
+            '--adapter-names',
+            metavar='adapter_name',
+            type=str,
+            default=None,
+            nargs='+',
+            help='the adapter names of lora model you want to load, each name' +
+            ' should be unique and needs to correspond to the path ' +
+            'one-to-one: [name1 name2 ...]')
+        # END
+
         return parser
 
     @classmethod
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index eabf620cef730..d75518d1becd9 100755
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -490,6 +490,20 @@ def from_engine_args(cls,
         # Initialize the cluster.
         distributed_init_method, placement_group = initialize_cluster(
             parallel_config, engine_args.engine_use_ray)
+
+        # =====================
+        # MODIFY HERE
+        lora_paths: list = engine_args.lora_paths
+        adapter_names: list = engine_args.adapter_names
+        lora_configs = None
+        if lora_paths is not None and adapter_names is not None:
+            assert len(lora_paths) == len(adapter_names), (len(lora_paths), len(adapter_names))
+            lora_configs = []
+            for lora_path, adapter_name in zip(lora_paths, adapter_names):
+                lora_configs.append((lora_path, adapter_name))
+        # =====================
+
+
         # Create the async LLM engine.
         engine = cls(parallel_config.worker_use_ray,
                      engine_args.engine_use_ray,
@@ -499,5 +513,6 @@ def from_engine_args(cls,
                      log_requests=not engine_args.disable_log_requests,
                      log_stats=not engine_args.disable_log_stats,
                      max_log_len=engine_args.max_log_len,
-                     start_engine_loop=start_engine_loop)
+                     start_engine_loop=start_engine_loop,
+                     lora_configs=lora_configs)  # MODIFY HERE
         return engine
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
old mode 100644
new mode 100755
index c3752b11f5660..82497437a52f0
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -68,6 +68,7 @@ def __init__(
         distributed_init_method: str,
         placement_group: Optional["PlacementGroup"],
         log_stats: bool,
+        lora_configs: List[Tuple[str, str]] = None  # MODIFY
     ) -> None:
         logger.info(
             "Initializing an LLM engine with config: "
@@ -93,6 +94,7 @@ def __init__(
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.log_stats = log_stats
+        self.lora_configs = lora_configs  # MODIFY
         self._verify_args()
 
         self.tokenizer = get_tokenizer(
@@ -137,6 +139,7 @@ def _init_workers(self, distributed_init_method: str):
             self.scheduler_config,
             0,
             distributed_init_method,
+            self.lora_configs,  # MODIFY
         )
         self.workers.append(worker)
         self._run_workers(
@@ -169,6 +172,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
+        lora_configs = copy.deepcopy(self.lora_configs)  # MODIFY
         self._run_workers("init_worker",
                           get_all_outputs=True,
                           worker_init_fn=lambda: Worker(
@@ -177,6 +181,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
                               scheduler_config,
                               None,
                               None,
+                              lora_configs,  # MODIFY
                           ))
         self._run_workers(
             "init_model",
@@ -227,11 +232,25 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine":
         # Initialize the cluster.
         distributed_init_method, placement_group = initialize_cluster(
             parallel_config)
+
+        # MODIFY
+        lora_paths: list = engine_args.lora_paths
+        adapter_names: list = engine_args.adapter_names
+        lora_configs = None
+        if lora_paths is not None and adapter_names is not None:
+            assert len(lora_paths) == len(adapter_names), (len(lora_paths),
+                                                           len(adapter_names))
+            lora_configs = []
+            for lora_path, adapter_name in zip(lora_paths, adapter_names):
+                lora_configs.append((lora_path, adapter_name))
+        # END
+
         # Create the LLM engine.
         engine = cls(*engine_configs,
                      distributed_init_method,
                      placement_group,
-                     log_stats=not engine_args.disable_log_stats)
+                     log_stats=not engine_args.disable_log_stats,
+                     lora_configs=lora_configs,)  # MODIFY
         return engine
 
     def add_request(
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
old mode 100644
new mode 100755
index 9dddfc1acd9cc..292a1763b6b67
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -71,6 +71,8 @@ def __init__(
         seed: int = 0,
         gpu_memory_utilization: float = 0.9,
         swap_space: int = 4,
+        lora_paths: List[str] = None,
+        adapter_names: List[str] = None,
         **kwargs,
     ) -> None:
         if "disable_log_stats" not in kwargs:
@@ -88,6 +90,8 @@ def __init__(
             seed=seed,
             gpu_memory_utilization=gpu_memory_utilization,
             swap_space=swap_space,
+            lora_paths=lora_paths,  # MODIFY
+            adapter_names=adapter_names,
             **kwargs,
         )
         self.llm_engine = LLMEngine.from_engine_args(engine_args)
diff --git a/vllm/model_executor/lora_utils.py b/vllm/model_executor/lora_utils.py
new file mode 100755
index 0000000000000..2aa8a85b02026
--- /dev/null
+++ b/vllm/model_executor/lora_utils.py
@@ -0,0 +1,134 @@
+from vllm.model_executor.parallel_utils.layers import BLoraColumnParallelLinear, BLoraRowParallelLinear, ColumnParallelLinear, RowParallelLinear
+from peft.tuners.lora import LoraLayer
+from peft import LoraConfig
+import re
+import torch
+
+WEIGHTS_NAME = "adapter_model.bin"
+PREFIX = "base_model.model."
+PARAMETER_PREFIX = "lora_"
+
+
+def _get_submodules(model, key):
+    parent = model.get_submodule(".".join(key.split(".")[:-1]))
+    target_name = key.split(".")[-1]
+    target = model.get_submodule(key)
+    return parent, target, target_name
+
+
+def _create_new_module(lora_config, adapter_name, target):
+    lora_alpha = lora_config.lora_alpha
+    r = lora_config.r
+    lora_dropout = lora_config.lora_dropout
+    if isinstance(target, ColumnParallelLinear):
+        new_module = BLoraColumnParallelLinear(
+            input_size=target.input_size,
+            output_size=target.output_size_per_partition,
+            adapter_name=adapter_name,
+            bias=target.bias,
+            gather_output=target.gather_output,
+            skip_bias_add=target.skip_bias_add,
+            quant_config=target.quant_config,
+            lora_alpha=lora_alpha,
+            r=r,
+            lora_dropout=lora_dropout)
+        return new_module
+    if isinstance(target, RowParallelLinear):
+        new_module = BLoraRowParallelLinear(
+            input_size=target.input_size_per_partition,
+            output_size=target.output_size,
+            adapter_name=adapter_name,
+            bias=target.bias,
+            input_is_parallel=target.input_is_parallel,
+            reduce_results=target.reduce_results,
+            skip_bias_add=target.skip_bias_add,
+            quant_config=target.quant_config,
+            lora_alpha=lora_alpha,
+            r=r,
+            lora_dropout=lora_dropout)
+        return new_module
+
+
+def _replace_module(parent, child_name, new_module, child):
+    setattr(parent, child_name, new_module)
+    new_module.weight = child.weight
+    if getattr(child, "state", None) is not None:
+        new_module.state = child.state
+        new_module.to(child.weight.device)
+    # dispatch to correct device
+    for name, module in new_module.named_modules():
+        if "lora_" in name:
+            module.to(child.weight.device)
+
+
+def _create_and_replace(lora_config, adapter_name, target, target_name,
+                        parent):
+    if (isinstance(target, (ColumnParallelLinear, RowParallelLinear))
+            and not isinstance(target, LoraLayer)):
+        new_module = _create_new_module(lora_config, adapter_name, target)
+        _replace_module(parent, target_name, new_module, target)
+    elif isinstance(target, LoraLayer):
+        target.update_layer(adapter_name, lora_config.r,
+                            lora_config.lora_alpha, lora_config.lora_dropout,
+                            lora_config.init_lora_weights)
+
+
+def add_lora_adapter(model: torch.nn.Module,
+                     lora_path: str,
+                     adapter_name: str):
+    lora_config = LoraConfig.from_pretrained(lora_path,
+                                             revision=None,
+                                             use_auth_token=None)
+    key_list = [key for key, _ in model.named_modules()]
+
+    # iterate the modules of LLaMa to insert the LoRA adapter
+
+    # TODO: we should re-construct the logic from here to fit LlaMa LoRA
+
+    for key in key_list:
+        # find target module
+        target_module_found = any(
+            re.match(f".*\\.{target_key}$", key)
+            for target_key in lora_config.target_modules) or any(
+                target_key == key for target_key in lora_config.target_modules)
+        if not target_module_found:
+            continue
+        parent, target, target_name = _get_submodules(model, key)
+        print(f"parent: {parent}, ")
+
+        # create and replace
+        _create_and_replace(lora_config, adapter_name, target, target_name,
+                            parent)
+
+    adapters_weights = torch.load(f"{lora_path}/{WEIGHTS_NAME}")
+
+    processed_adapter_state_dict = {}
+    for key, value in adapters_weights.items():
+        if key.startswith(PREFIX):
+            new_key = key[len(PREFIX):]
+        else:
+            new_key = key
+        processed_adapter_state_dict[new_key] = value
+
+    state_dict = {}
+    for k, v in processed_adapter_state_dict.items():
+        if PARAMETER_PREFIX in k:
+            suffix = k.split(PARAMETER_PREFIX)[1]
+            if "." in suffix:
+                to_replace = ".".join(suffix.split(".")[1:])
+                k = k.replace(to_replace, f"{adapter_name}.{to_replace}")
+            else:
+                k = f"{k}.{adapter_name}"
+        state_dict[k] = v
+
+    # print("====== LORA ======")
+    # for name in state_dict.keys():
+    #     print(name)
+
+    # print("====== MODEL ======")
+    # for name in model.state_dict().keys():
+    #     print(name)
+
+
+    model.load_lora_weights_parallel(state_dict)
+    model.cuda()
\ No newline at end of file
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
old mode 100644
new mode 100755
index b18f99223f10a..66725d1de3012
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -1,6 +1,6 @@
 """Utilities for selecting and loading models."""
 import contextlib
-from typing import Type
+from typing import Type, List, Tuple
 
 import torch
 import torch.nn as nn
@@ -11,6 +11,8 @@
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)
 
+from vllm.model_executor.lora_utils import add_lora_adapter  # MODIFY
+
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
     "AquilaModel": AquilaForCausalLM,
@@ -64,7 +66,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
         f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
 
 
-def get_model(model_config: ModelConfig) -> nn.Module:
+def get_model(model_config: ModelConfig,
+              lora_configs: List[Tuple[str, str]] = None) -> nn.Module:  # MODIFY
     model_class = _get_model_architecture(model_config.hf_config)
 
     # Get the quantization config.
@@ -108,4 +111,19 @@ def get_model(model_config: ModelConfig) -> nn.Module:
     model.load_weights(model_config.model, model_config.download_dir,
                        model_config.load_format, model_config.revision)
     model = model.cuda()
+
+    # print("====== MODEL ======")
+    # for name in model.state_dict().keys():
+    #     print(name)
+
+    # MODIFY
+    # load lora adapter
+    if lora_configs is not None:
+        for lora_config in lora_configs:
+            lora_path = lora_config[0]
+            adapter_name = lora_config[1]
+            add_lora_adapter(model=model,
+                             lora_path=lora_path,
+                             adapter_name=adapter_name)
+    # END
     return model.eval()
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
old mode 100644
new mode 100755
index 735e4ad172182..b517bab513b7d
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -39,7 +39,11 @@
 from vllm.model_executor.layers.quantized_linear import ParallelLinear
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       BLoraColumnParallelLinear,
+                                                       BLoraRowParallelLinear,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)  # MODIFY
 from vllm.model_executor.quantization_utils import QuantizationConfig
 from vllm.model_executor.weight_utils import (
     convert_pyslice_to_tensor, hf_model_weights_iterator,
@@ -270,6 +274,28 @@ def forward(
         hidden_states = self.norm(hidden_states)
         return hidden_states
 
+# MODIFY
+class NormHead(ColumnParallelLinear):
+
+    def __init__(self, hidden_size, vocab_size, bias=False):
+        super().__init__(hidden_size,
+                         vocab_size,
+                         bias=False,
+                         gather_output=False)
+        self.first_flag = True
+
+    def get_weight(self):
+        if self.first_flag:
+            self.first_flag = False
+            self.weight = nn.Parameter(nn.functional.normalize(self.weight))
+        return self.weight
+
+    def forward(self, hidden_states):
+        if self.first_flag:
+            self.first_flag = False
+            self.weight = nn.Parameter(nn.functional.normalize(self.weight))
+        return ColumnParallelLinear.forward(self, hidden_states)
+
 
 class LlamaForCausalLM(nn.Module):
 
@@ -299,6 +325,37 @@ def forward(
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
     ) -> SamplerOutput:
+
+        # MODIFY
+        # create and set lora mask
+        batch_lora_ids = {}
+        total_length = input_ids.shape[0]
+        index = 0
+        for seq_groups in input_metadata.seq_groups:
+            seq_ids = seq_groups[0]
+            sampling_params = seq_groups[1]
+            for i in range(len(seq_ids)):
+                lora_id = sampling_params.lora_id
+                index_set = batch_lora_ids.get(lora_id, set())
+                index_set.add(index)
+                batch_lora_ids[lora_id] = index_set
+                index += 1
+        # lora mask for compute lora in a batch: lora_id -> mask
+        lora_masks = {}
+        for lora_id, pos in batch_lora_ids.items():
+            mask = torch.zeros(total_length, device=input_ids.device)
+            for i in range(total_length):
+                if i in pos:
+                    mask[i] = 1
+            lora_masks[lora_id] = mask
+
+        for _, module in self.model.named_modules():
+            if isinstance(module,
+                          (BLoraColumnParallelLinear, BLoraRowParallelLinear)):
+                module.lora_masks = lora_masks
+        # END
+
+
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
@@ -426,3 +483,37 @@ def load_weights(self,
             load_tensor_parallel_weights(param, loaded_weight, name,
                                          column_parallel_weights,
                                          row_parallel_weights, tp_rank)
+
+    # MODIFY
+    def load_lora_weights_parallel(self, lora_state_dict: dict):
+        model_state_dict = self.state_dict()
+
+        # for name in model_state_dict.keys():
+        #     print(name)
+
+        tp_rank = get_tensor_model_parallel_rank()
+        for name, loaded_weight in lora_state_dict.items():
+            if name not in model_state_dict.keys():
+                raise ValueError(f"No module named {name} " +
+                                 f"in base model: {model_state_dict.keys()}")
+            param = model_state_dict[name]
+            column_parallel_weights = []
+            row_parallel_weights = []
+            if "W_pack" in name:
+                if "lora_B" in name:
+                    column_parallel_weights.append("lora_B")
+
+            elif "o_proj" in name:
+                if "lora_A" in name:
+                    row_parallel_weights.append("lora_A")
+            else:
+                raise ValueError("Only support target module for W_pack" +
+                                 f"and o_proj now! Target module:{name}")
+            load_tensor_parallel_weights(
+                param,
+                loaded_weight,
+                name,
+                column_parallel_weights,
+                row_parallel_weights,
+                tp_rank,
+            )
\ No newline at end of file
diff --git a/vllm/model_executor/parallel_utils/layers.py b/vllm/model_executor/parallel_utils/layers.py
old mode 100644
new mode 100755
index c1aea2c1d5543..0fe3e7dbd4b6c
--- a/vllm/model_executor/parallel_utils/layers.py
+++ b/vllm/model_executor/parallel_utils/layers.py
@@ -11,6 +11,10 @@
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
+# MODIFY
+from peft.tuners.lora import LoraLayer
+# END
+
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -301,3 +305,101 @@ def forward(self, input_):
             output = output_
         output_bias = self.bias
         return output, output_bias
+
+# MODIFY
+def compulate_lora(obj: LoraLayer, x: torch.Tensor, output: torch.Tensor,
+                   lora_masks: dict[str, torch.Tensor]) -> torch.Tensor:
+    lora_out = torch.zeros_like(output)
+    for lora_id, lora_mask in lora_masks.items():
+        # compute lora separately and use mask to filter
+        if lora_id in obj.lora_A.keys():
+            lora_result = obj.scaling[lora_id] * obj.lora_B[lora_id](
+                obj.lora_A[lora_id](x))
+            lora_out += (lora_result * lora_mask.unsqueeze(1).unsqueeze(2))
+    return lora_out
+
+
+class BLoraColumnParallelLinear(ColumnParallelLinear, LoraLayer):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        adapter_name: str,
+        bias: bool = True,
+        gather_output: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        **kwargs,
+    ):
+        init_lora_weights = kwargs.pop('init_lora_weights', True)
+
+        ColumnParallelLinear.__init__(self, input_size, output_size, bias,
+                                      gather_output, skip_bias_add,
+                                      params_dtype, quant_config)
+        LoraLayer.__init__(self,
+                           in_features=input_size,
+                           out_features=output_size)
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout,
+                          init_lora_weights)
+        self.active_adapter_ = adapter_name
+
+    def forward(self, x: torch.Tensor):
+        previous_dtype = x.dtype
+        output, output_bias = ColumnParallelLinear.forward(self, x)
+        x = x.to(self.lora_A[self.active_adapter_].weight.dtype)
+        lora_out = compulate_lora(self, x, output, self.lora_masks)
+        output += lora_out
+        output = output.to(previous_dtype)
+        if output_bias is not None:
+            output_bias = output_bias.to(previous_dtype)
+
+        return output, output_bias
+
+
+class BLoraRowParallelLinear(RowParallelLinear, LoraLayer):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        adapter_name: str,
+        bias: bool = True,
+        input_is_parallel: bool = False,
+        reduce_results: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        r: int = 0,
+        lora_alpha: int = 1,
+        lora_dropout: float = 0.0,
+        **kwargs,
+    ):
+        init_lora_weights = kwargs.pop('init_lora_weights', True)
+
+        RowParallelLinear.__init__(self, input_size, output_size, bias,
+                                   input_is_parallel, skip_bias_add,
+                                   params_dtype, reduce_results, quant_config)
+        LoraLayer.__init__(self,
+                           in_features=input_size,
+                           out_features=output_size)
+        self.update_layer(adapter_name, r, lora_alpha, lora_dropout,
+                          init_lora_weights)
+        self.active_adapter_ = adapter_name
+
+    def forward(self, x: torch.Tensor):
+        previous_dtype = x.dtype
+        output, output_bias = RowParallelLinear.forward(self, x)
+        x = x.to(self.lora_A[self.active_adapter_].weight.dtype)
+        lora_out = compulate_lora(self, x, output, self.lora_masks)
+        output += lora_out
+
+        output = output.to(previous_dtype)
+        if output_bias is not None:
+            output_bias = output_bias.to(previous_dtype)
+
+        return output, output_bias
\ No newline at end of file
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
old mode 100644
new mode 100755
index f8ef9be7b6a62..988634db564ab
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -106,6 +106,7 @@ def __init__(
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessor]] = None,
+        lora_id: str = None,  # MODIFY
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
@@ -118,6 +119,7 @@ def __init__(
         self.use_beam_search = use_beam_search
         self.length_penalty = length_penalty
         self.early_stopping = early_stopping
+        self.lora_id = lora_id  # MODIFY
         if stop is None:
             self.stop = []
         elif isinstance(stop, str):
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
old mode 100644
new mode 100755
index bbbc2e7f45a6e..8e6b59216418b
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -31,6 +31,7 @@ def __init__(
         scheduler_config: SchedulerConfig,
         rank: Optional[int] = None,
         distributed_init_method: Optional[str] = None,
+        lora_configs: List[Tuple[str, str]] = None,  # MODIFY
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
@@ -38,6 +39,7 @@ def __init__(
         self.rank = rank
         self.distributed_init_method = distributed_init_method
+        self.lora_configs = lora_configs  # MODIFY
 
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None
@@ -67,7 +69,8 @@ def init_model(self):
         # Initialize the model.
         set_random_seed(self.model_config.seed)
-        self.model = get_model(self.model_config)
+        # self.model = get_model(self.model_config)
+        self.model = get_model(self.model_config, self.lora_configs)  # MODIFY
 
     @torch.inference_mode()
     def profile_num_available_blocks(
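
The following standalone sketch is not part of the diff above; it only illustrates the batched multi-LoRA idea the patch implements. In the patch, each request carries a lora_id in its SamplingParams, LlamaForCausalLM.forward builds a 0/1 mask per adapter over the flattened token batch, and compulate_lora() adds each adapter's masked low-rank delta on top of the base parallel-linear output. The names here (batched_lora_forward, row_lora_ids, base_weight) are illustrative only and do not exist in vLLM or in this patch, and the sketch assumes plain 2-D [tokens, hidden] activations rather than the patch's tensor-parallel layers.

import torch


def batched_lora_forward(x, base_weight, adapters, scaling, row_lora_ids):
    # x: [num_tokens, hidden]; base_weight: [out_features, hidden]
    # adapters: adapter_name -> (A [r, hidden], B [out_features, r])
    # row_lora_ids: adapter name requested by each row (token) of x
    output = x @ base_weight.t()
    for name, (A, B) in adapters.items():
        # 0/1 mask selecting the rows that requested this adapter
        mask = torch.tensor(
            [1.0 if rid == name else 0.0 for rid in row_lora_ids],
            device=x.device).unsqueeze(1)
        delta = (x @ A.t()) @ B.t() * scaling[name]
        output += delta * mask  # rows belonging to other adapters stay untouched
    return output


if __name__ == "__main__":
    torch.manual_seed(0)
    hidden, out_features, r = 8, 8, 2
    base = torch.randn(out_features, hidden)
    adapters = {
        "adapter_1": (torch.randn(r, hidden), torch.randn(out_features, r)),
        "adapter_2": (torch.randn(r, hidden), torch.randn(out_features, r)),
    }
    scaling = {"adapter_1": 0.5, "adapter_2": 0.5}
    x = torch.randn(4, hidden)
    # two sequences batched together, each routed to a different adapter
    y = batched_lora_forward(x, base, adapters, scaling,
                             ["adapter_1", "adapter_1", "adapter_2", "adapter_2"])
    print(y.shape)  # torch.Size([4, 8])

Rows whose lora_id matches no adapter receive only the base projection, which is how a single forward pass can serve several adapters, and the base model, in one batch.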