From de5a4a937f482668578bc006012293563979cc7d Mon Sep 17 00:00:00 2001
From: Kai Huang
Date: Wed, 5 Jun 2024 11:28:24 +0800
Subject: [PATCH] split qkv

Split query/key/value along the head dimension and run
scaled_dot_product_attention one block of 16 heads at a time, route the
sampler's logits projection through the lm_head module instead of a raw
matmul against its weight, and add timing prints around model execution.

---
 .../layers/attention/backends/torch_sdpa.py | 20 ++++++++++++--------
 vllm/model_executor/layers/sampler.py       |  2 +-
 vllm/model_executor/models/llama.py         |  2 +-
 vllm/worker/model_runner.py                 |  5 +++++
 4 files changed, 19 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/layers/attention/backends/torch_sdpa.py b/vllm/model_executor/layers/attention/backends/torch_sdpa.py
index 2e7b155ea6868..690d40de0e76f 100644
--- a/vllm/model_executor/layers/attention/backends/torch_sdpa.py
+++ b/vllm/model_executor/layers/attention/backends/torch_sdpa.py
@@ -91,14 +91,18 @@ def forward(
         query = query.movedim(1, query.dim() - 2)
         key = key.movedim(1, key.dim() - 2)
         value = value.movedim(1, value.dim() - 2)
-        out = torch.nn.functional.scaled_dot_product_attention(
-            query,
-            key,
-            value,
-            input_metadata.attn_bias,
-            0.0,
-            is_causal=not self.need_mask,
-            scale=self.scale).movedim(query.dim() - 2, 1).contiguous()
+
+        out = []
+        block_size = 16
+        query_split = torch.split(query, block_size, dim=1)
+        key_split = torch.split(key, block_size, dim=1)
+        value_split = torch.split(value, block_size, dim=1)
+        for q, k, v in zip(query_split, key_split, value_split):
+            out_split = torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, input_metadata.attn_bias, 0.0,
+                is_causal=not self.need_mask, scale=self.scale)
+            out.append(out_split)
+        out = torch.cat(out, dim=1).movedim(query.dim() - 2, 1).contiguous()
         # output = out.view_as(query)
         # FIXME: half input will generate float output, next ipex release will fix this.
         output = out.view_as(query).to(query.dtype)
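
Note on the chunked attention loop above: scaled_dot_product_attention
treats attention heads independently, so splitting query/key/value along
the head dimension (dim 1 after the movedim calls) and concatenating the
per-block outputs should match a single full call. A minimal standalone
sketch of that equivalence, with illustrative shapes rather than vLLM's
real tensors:

    import torch
    import torch.nn.functional as F

    # Illustrative shapes: (batch, heads, seq_len, head_dim).
    q = torch.randn(2, 32, 128, 64)
    k = torch.randn(2, 32, 128, 64)
    v = torch.randn(2, 32, 128, 64)

    # One full SDPA call over all 32 heads.
    full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # The same computation, 16 heads at a time, as in the patch.
    block_size = 16
    chunks = [
        F.scaled_dot_product_attention(qs, ks, vs, is_causal=True)
        for qs, ks, vs in zip(q.split(block_size, dim=1),
                              k.split(block_size, dim=1),
                              v.split(block_size, dim=1))
    ]
    assert torch.allclose(torch.cat(chunks, dim=1), full, atol=1e-5)

One caveat: if a model has fewer key/value heads than query heads
(grouped-query attention), zip() over the per-tensor splits would
silently drop the trailing query blocks, so the split counts are worth
asserting before the loop.
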
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 341825035cfd1..0e8469c8f41f0 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -43,7 +43,7 @@ def __init__(self,
     def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
                     embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
+        logits = embedding(hidden_states)
         if embedding_bias is not None:
             logits += embedding_bias
         logits = tensor_model_parallel_gather(logits)
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 4c163dfdab537..8a46bb39758d3 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -343,7 +343,7 @@ def sample(
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+        next_tokens = self.sampler(self.lm_head, hidden_states,
                                    sampling_metadata)
         return next_tokens
 
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 7afda48efa207..9ab5c98a1a0f4 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -577,6 +577,9 @@ def execute_model(
             model_executable = self.graph_runners[graph_batch_size]
         else:
             model_executable = self.model
+        print(input_tokens.shape)
+        import time
+        start = time.time()
         hidden_states = model_executable(
             input_ids=input_tokens,
             positions=input_positions,
@@ -589,6 +592,8 @@
             hidden_states=hidden_states,
             sampling_metadata=sampling_metadata,
         )
+        end = time.time()
+        print("Time used: ", (end - start) * 1000, "ms")
         return output
 
     @torch.inference_mode()
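
Note on the sampler change: passing the lm_head module itself rather
than its weight tensor lets _get_logits run the projection through the
module's forward, where a backend-optimized linear kernel can take over,
instead of an explicit matmul against the transposed weight. A minimal
sketch of why the two paths agree, using nn.Linear as a stand-in for
vLLM's parallel lm_head (sizes are illustrative):

    import torch

    # nn.Linear stands in for the real lm_head; no bias, as in LLaMA.
    lm_head = torch.nn.Linear(4096, 32000, bias=False)
    hidden_states = torch.randn(8, 4096)

    logits_old = torch.matmul(hidden_states, lm_head.weight.t())  # old path
    logits_new = lm_head(hidden_states)                           # new path

    assert torch.allclose(logits_old, logits_new, atol=1e-4)

On the timing added in model_runner.py: time.time() around the forward
pass measures host-side wall clock only; if the backend executes
asynchronously, a device synchronize before reading the clock is needed
for the printed milliseconds to be meaningful.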