From 1ab029d4a9d8eeaeacad70ae328a134f8c861977 Mon Sep 17 00:00:00 2001
From: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com>
Date: Tue, 24 Oct 2023 09:50:55 +0800
Subject: [PATCH] Fix cuda oom (#7)

add torch.cuda.empty_cache()
---
 vllm/model_executor/models/bigdl_llama.py | 25 ++++++++++++++---------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/models/bigdl_llama.py b/vllm/model_executor/models/bigdl_llama.py
index 8683ab256a7af..00402cf7d6903 100644
--- a/vllm/model_executor/models/bigdl_llama.py
+++ b/vllm/model_executor/models/bigdl_llama.py
@@ -48,6 +48,7 @@ def __init__(
         self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.dtype = self.model.config.torch_dtype
+        # self.tmp_kv_cache = []
 
     def decode(self, generated_ids: List[int]) -> str:
         return self.tokenizer.decode(
@@ -76,7 +77,6 @@ def forward(
             all_decoding = all_decoding and (not seq_group_meta_data.is_prompt)
             seq_ids = list(seq_group_meta_data.seq_data.keys())
             seq_id = seq_ids[0]
-            print(seq_id)
             cur_seq_ids.append(seq_id)
             seq_data = seq_group_meta_data.seq_data[seq_id]
 
@@ -93,9 +93,13 @@ def forward(
             for seq_group_meta_data in seq_group_meta_data_lists:
                 seq_ids = list(seq_group_meta_data.seq_data.keys())
                 seq_id = seq_ids[0]
+                if kv_cache.get(seq_id) is None:
+                    continue
                 for i in range(kv_cache_0):
                     for j in range(kv_cache_1):
-                        bigdl_kv_cache[i][j] = torch.cat((bigdl_kv_cache[i][j], kv_cache[seq_id][i][j]), dim=0).to(dtype = self.dtype)
+                        target_size = (bigdl_kv_cache[i][j].size(0) + kv_cache[seq_id][i][j].size(0),) + kv_cache[seq_id][i][j].size()[1:]
+                        bigdl_kv_cache[i][j].resize_(target_size)
+                        bigdl_kv_cache[i][j][-kv_cache[seq_id][i][j].size(0):] = kv_cache[seq_id][i][j]
 
         bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
         bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
@@ -115,12 +119,13 @@ def forward(
             "use_cache": True,
             "return_dict": True,
         }
-        # kwargs["position_ids"] = position_ids
         # pdb.set_trace()
         outputs = self.model.forward(**kwargs)
+        # self.tmp_kv_cache = outputs.past_key_values
         index = 0
         bigdl_output = []
         for seq_id in cur_seq_ids:
+            # pdb.set_trace()
             cur_sampling_params = bigdl_sampling_params[seq_id]
             logits_processor = prepare_logits_processor(
                 cur_sampling_params.temperature, 1,
@@ -143,16 +148,16 @@ def forward(
             kv_cache[seq_id] = [[[] for _ in range(kv_cache_1)] for _ in range(kv_cache_0)]
             for i in range(kv_cache_0):
                 for j in range(kv_cache_1):
-                    kv_cache[seq_id][i][j] = outputs.past_key_values[i][j][index].unsqueeze(0).to(device=self.device,dtype = self.dtype)
+                    kv_cache[seq_id][i][j] = outputs.past_key_values[i][j][index].unsqueeze(0)
             index = index + 1
-
-        #pdb.set_trace()
+
+        torch.cuda.empty_cache()
         return bigdl_output
 
     def load_weights(self,
-                    model_name_or_path: str,
-                    cache_dir: Optional[str] = None,
-                    load_format: str = "auto",
-                    revision: Optional[str] = None):
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         pass
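
Note (not part of the patch itself): the change combines two memory-saving patterns, growing the batched KV cache in place instead of concatenating, and releasing cached allocator blocks after each step. Below is a minimal standalone sketch of those two patterns; the tensor names and shapes are illustrative assumptions, not the real structures from bigdl_llama.py.

    # Illustrative sketch only: `cache` and `new_block` are hypothetical stand-ins
    # for one layer's key or value tensor; shapes are made up for the example.
    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cache = torch.zeros(2, 8, 16, 64, device=device)      # existing cache entries
    new_block = torch.zeros(3, 8, 16, 64, device=device)  # entries to append

    # Pattern replaced by the patch: torch.cat allocates a brand-new tensor each
    # step, so peak memory briefly holds both the old and the new cache.
    # cache = torch.cat((cache, new_block), dim=0)

    # Pattern added by the patch: enlarge the existing tensor in place with
    # resize_() (existing elements are preserved, new memory is uninitialized)
    # and copy the new block into the tail slice.
    target_size = (cache.size(0) + new_block.size(0),) + new_block.size()[1:]
    cache.resize_(target_size)
    cache[-new_block.size(0):] = new_block

    # After a decoding step, hand cached but unused allocator memory back to the
    # driver so the process is less likely to run into CUDA out-of-memory.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

torch.cuda.empty_cache() does not free tensors that are still referenced; it only trims PyTorch's caching allocator, which is why the patch also drops the extra .to(...) copies when storing past_key_values.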