
Commit 2d61054: cleanup
Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
LucasWilkinson and alexm-redhat committed Jan 31, 2025
1 parent f2b2500 commit 2d61054
Showing 5 changed files with 14 additions and 26 deletions.
vllm/attention/backends/abstract.py: 3 changes (1 addition & 2 deletions)
@@ -168,8 +168,7 @@ def __init__(self, runner: "ModelRunnerBase"):

     @abstractmethod
     @contextmanager
-    def graph_capture(self, max_batch_size: int,
-                      positions: Optional[torch.Tensor]):
+    def graph_capture(self, max_batch_size: int):
         """Context manager used when capturing CUDA graphs."""
         yield

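For reference, `graph_capture` is an abstract context manager that each attention-state class implements; after this change it takes only `max_batch_size`. A minimal sketch of the pattern (the class name and try/finally cleanup here are illustrative, not vLLM's exact code):

```python
from contextlib import contextmanager


class ExampleAttentionState:
    """Illustrative stand-in for an AttentionState subclass."""

    def __init__(self):
        self._is_graph_capturing = False

    @contextmanager
    def graph_capture(self, max_batch_size: int):
        # Allocate persistent buffers sized for the largest batch here,
        # then mark the state as capturing for the duration of the block.
        self._is_graph_capturing = True
        try:
            yield
        finally:
            self._is_graph_capturing = False
```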
vllm/attention/backends/flashinfer.py: 3 changes (1 addition & 2 deletions)
@@ -213,8 +213,7 @@ def _get_decode_wrapper(self):
         return self._decode_wrapper

     @contextmanager
-    def graph_capture(self, max_batch_size: int,
-                      positions: Optional[torch.Tensor]):
+    def graph_capture(self, max_batch_size: int):
         self._is_graph_capturing = True
         self._graph_decode_wrapper = None
         self._graph_slot_mapping = torch.full((max_batch_size, ),
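The FlashInfer state pre-sizes its capture buffers for the largest batch, so smaller batches reuse leading slices of the same memory. A rough sketch of that preallocation pattern (the `PAD_SLOT_ID` value and buffer names are assumptions for illustration):

```python
import torch

PAD_SLOT_ID = -1  # illustrative padding value, assumed here
max_batch_size = 8

# Buffers allocated once at the maximum graph size; each captured graph
# reads only the leading slice it needs.
graph_slot_mapping = torch.full((max_batch_size, ),
                                PAD_SLOT_ID,
                                dtype=torch.long)
graph_seq_lens = torch.ones(max_batch_size, dtype=torch.int32)
```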
vllm/attention/backends/triton_mla.py: 16 changes (7 additions & 9 deletions)
@@ -123,10 +123,7 @@ def graph_clone(self, batch_size: int):
         return self.__class__(self.runner)

     def graph_capture_get_metadata_for_batch(
-            self,
-            batch_size: int,
-            is_encoder_decoder_model: bool = False,
-            positions: Optional[torch.Tensor] = None):
+            self, batch_size: int, is_encoder_decoder_model: bool = False):
         assert self._is_graph_capturing

         attn_metadata = self.runner.attn_backend.make_metadata(
@@ -175,15 +172,16 @@ def prepare_graph_input_buffers(self,
                                     input_buffers,
                                     attn_metadata,
                                     is_encoder_decoder_model: bool = False):
+        input_positions = attn_metadata.input_positions
+        num_positions = input_positions.shape[0]
         input_buffers["seq_lens_tensor"].copy_(
             attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
         input_buffers["block_tables"].copy_(
             attn_metadata.decode_metadata.block_tables, non_blocking=True)
-        input_buffers["input_positions"][:attn_metadata.decode_metadata.
-                                         input_positions.shape[0]].copy_(
-                                             attn_metadata.decode_metadata.
-                                             input_positions,
-                                             non_blocking=True)
+        # CUDA graph buffer is padded so only perform a partial copy based on
+        # num_positions
+        input_buffers["input_positions"][:num_positions].copy_(
+            input_positions, non_blocking=True)
         if is_encoder_decoder_model:
             raise NotImplementedError(
                 "TritonMLAState does not support encoder/decoder yet")
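The new comment in this hunk captures the point of the change: CUDA graphs replay fixed memory addresses, so the `input_positions` buffer is padded to the maximum batch size and only its leading slice is overwritten on each run. A standalone sketch of the idea (names and shapes are illustrative):

```python
import torch

max_batch_size = 8
# Persistent buffer referenced by the captured CUDA graph, padded to the
# maximum batch size.
input_positions_buffer = torch.zeros(max_batch_size, dtype=torch.long)

# At replay time the live positions tensor may be shorter than the buffer.
positions = torch.tensor([5, 6, 7], dtype=torch.long)
num_positions = positions.shape[0]

# Partial copy: only the leading slice is updated; the padding past
# num_positions is left untouched.
input_positions_buffer[:num_positions].copy_(positions, non_blocking=True)
```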
vllm/attention/backends/utils.py: 12 changes (3 additions & 9 deletions)
@@ -2,8 +2,7 @@
 from collections import defaultdict
 from contextlib import contextmanager
 from itertools import accumulate
-from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
-                    TypeVar, Union)
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union

 import numpy as np
 import torch
@@ -289,9 +288,7 @@ def __init__(self, runner: "ModelRunnerBase"):
         self._is_graph_capturing = False

     @contextmanager
-    def graph_capture(self, max_batch_size: int,
-                      positions: Optional[torch.Tensor]):
-        assert positions is None
+    def graph_capture(self, max_batch_size: int):

         self._is_graph_capturing = True

@@ -317,10 +314,7 @@ def graph_clone(self, batch_size: int) -> "CommonAttentionState":
         return self.__class__(self.runner)

     def graph_capture_get_metadata_for_batch(
-            self,
-            batch_size: int,
-            is_encoder_decoder_model: bool = False,
-            positions: Optional[torch.Tensor] = None):
+            self, batch_size: int, is_encoder_decoder_model: bool = False):
         assert self._is_graph_capturing
         attn_metadata = self.runner.attn_backend.make_metadata(
             num_prefills=0,
vllm/worker/model_runner.py: 6 changes (2 additions & 4 deletions)
@@ -1483,13 +1483,11 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                 self.vllm_config.compilation_config.
                 cudagraph_capture_sizes)
             for batch_size in cudagraph_capture_sizes:
-                cur_input_positions = input_positions[..., :batch_size]
                 attn_metadata = (
                     self.attn_state.graph_capture_get_metadata_for_batch(
                         batch_size,
                         is_encoder_decoder_model=self.model_config.
-                        is_encoder_decoder,
-                        positions=cur_input_positions))
+                        is_encoder_decoder))
                 # Disable KV Scale Calculation for graph capture
                 attn_metadata.enable_kv_scales_calculation = False
                 if self.lora_config:
@@ -1515,7 +1513,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                     "input_ids":
                     input_tokens[:batch_size],
                     "positions":
-                    cur_input_positions,
+                    input_positions[..., :batch_size],
                     "intermediate_inputs":
                     intermediate_inputs[:batch_size]
                     if intermediate_inputs is not None else None,
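With `positions` no longer threaded through `graph_capture_get_metadata_for_batch`, the runner slices the padded positions tensor inline at each capture size. A toy illustration of that slicing (the tensor shape is an assumption, not vLLM's real layout):

```python
import torch

max_batch_size = 8
# Positions tensor allocated once at the maximum capture size.
input_positions = torch.zeros(1, max_batch_size, dtype=torch.long)

for batch_size in (1, 2, 4, 8):
    # Each graph is captured over a view of the leading batch_size
    # columns, so no separate per-size tensor is materialized.
    positions_view = input_positions[..., :batch_size]
    assert positions_view.shape[-1] == batch_size
```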
