Commit d47b834
Remove redundant code paths
gshtras committed Jan 31, 2025
1 parent 339ba27 commit d47b834
Showing 1 changed file with 19 additions and 42 deletions.
61 changes: 19 additions & 42 deletions vllm/attention/backends/utils.py
@@ -11,7 +11,6 @@
     AttentionState)
 from vllm.attention.backends.abstract import AttentionType
 from vllm.multimodal import MultiModalPlaceholderMap
-from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 
 if TYPE_CHECKING:
@@ -346,22 +345,15 @@ def graph_capture_get_metadata_for_batch(
             use_cuda_graph=True,
         )
         if is_encoder_decoder_model:
-            # The encoder decoder model works only with XFormers backend.
-            # Assert the same.
-            if current_platform.is_rocm():
-                assert (self.runner.attn_backend.get_name() == "ROCM_FLASH"), (
-                    f"Expected attn_backend name to be 'ROCM_FLASH', but "
-                    f" got '{self.runner.attn_backend.get_name()}'")
-                self._update_captured_metadata_for_enc_dec_model(
-                    batch_size=batch_size, attn_metadata=attn_metadata)
-            else:
-                assert self.runner.attn_backend.get_name() in\
-                    ["XFORMERS", "FLASH_ATTN"], \
-                    f"Expected attn_backend name to be either 'XFORMERS' or " \
-                    f"'FLASH_ATTN', but "\
-                    f"got '{self.runner.attn_backend.get_name()}'"
-                self._update_captured_metadata_for_enc_dec_model(
-                    batch_size=batch_size, attn_metadata=attn_metadata)
+            # The encoder decoder model works only with XFormers and
+            # Flash Attention backend. Assert the same.
+            assert self.runner.attn_backend.get_name() in\
+                ["XFORMERS", "FLASH_ATTN"], \
+                f"Expected attn_backend name to be either 'XFORMERS' or " \
+                f"'FLASH_ATTN', but "\
+                f"got '{self.runner.attn_backend.get_name()}'"
+            self._update_captured_metadata_for_enc_dec_model(
+                batch_size=batch_size, attn_metadata=attn_metadata)
 
         return attn_metadata

@@ -377,20 +369,13 @@ def get_graph_input_buffers(
         if is_encoder_decoder_model:
             # The encoder decoder model works only with XFormers and
             # Flash Attention backend. Assert the same.
-            if current_platform.is_rocm():
-                assert (self.runner.attn_backend.get_name() == "ROCM_FLASH"), (
-                    f"Expected attn_backend name to be 'ROCM_FLASH', but "
-                    f" got '{self.runner.attn_backend.get_name()}'")
-                self._add_additonal_input_buffers_for_enc_dec_model(
-                    attn_metadata=attn_metadata, input_buffers=input_buffers)
-            else:
-                assert self.runner.attn_backend.get_name() in\
-                    ["XFORMERS", "FLASH_ATTN"], \
-                    f"Expected attn_backend name to be either 'XFORMERS' or "\
-                    f"'FLASH_ATTN', but "\
-                    f"got '{self.runner.attn_backend.get_name()}'"
-                self._add_additonal_input_buffers_for_enc_dec_model(
-                    attn_metadata=attn_metadata, input_buffers=input_buffers)
+            assert self.runner.attn_backend.get_name() in\
+                ["XFORMERS", "FLASH_ATTN"], \
+                f"Expected attn_backend name to be either 'XFORMERS' or "\
+                f"'FLASH_ATTN', but "\
+                f"got '{self.runner.attn_backend.get_name()}'"
+            self._add_additonal_input_buffers_for_enc_dec_model(
+                attn_metadata=attn_metadata, input_buffers=input_buffers)
         return input_buffers
 
     def prepare_graph_input_buffers(
@@ -405,21 +390,13 @@ def prepare_graph_input_buffers(
         if is_encoder_decoder_model:
             # The encoder decoder model works only with XFormers and
             # Flash Attention backend. Assert the same.
-
-            if current_platform.is_rocm():
-                assert (self.runner.attn_backend.get_name() == "ROCM_FLASH"), (
-                    f"Expected attn_backend name to be 'ROCM_FLASH', but "
-                    f" got '{self.runner.attn_backend.get_name()}'")
-                self._prepare_input_buffers_for_enc_dec_model(
-                    attn_metadata, input_buffers)
-            else:
-                assert self.runner.attn_backend.get_name() in\
-                    ["XFORMERS", "FLASH_ATTN"], \
-                    f"Expected attn_backend name to be either 'XFORMERS' or "\
-                    f"'FLASH_ATTN', but "\
-                    f"got '{self.runner.attn_backend.get_name()}'"
-                self._prepare_input_buffers_for_enc_dec_model(
-                    attn_metadata, input_buffers)
+            assert self.runner.attn_backend.get_name() in\
+                ["XFORMERS", "FLASH_ATTN"], \
+                f"Expected attn_backend name to be either 'XFORMERS' or "\
+                f"'FLASH_ATTN', but "\
+                f"got '{self.runner.attn_backend.get_name()}'"
+            self._prepare_input_buffers_for_enc_dec_model(
+                attn_metadata, input_buffers)
 
     def begin_forward(self, model_input) -> None:
         return
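All three call sites in this diff now share the same membership check on the backend name, replacing the old per-platform branches. Below is a minimal, self-contained sketch of that pattern for illustration only; FakeBackend and validate_enc_dec_backend are hypothetical names invented for this example and are not part of vLLM's API.

# Minimal sketch of the consolidated backend-name check.
# FakeBackend and validate_enc_dec_backend are hypothetical,
# not vLLM identifiers.

# Backends the encoder-decoder code path accepts, per the diff above.
_SUPPORTED_ENC_DEC_BACKENDS = ["XFORMERS", "FLASH_ATTN"]


class FakeBackend:
    """Stand-in for an attention backend exposing get_name()."""

    def __init__(self, name: str) -> None:
        self._name = name

    def get_name(self) -> str:
        return self._name


def validate_enc_dec_backend(backend: FakeBackend) -> None:
    # One membership test covers every platform, mirroring the
    # unified assert introduced by this commit.
    assert backend.get_name() in _SUPPORTED_ENC_DEC_BACKENDS, (
        f"Expected attn_backend name to be either 'XFORMERS' or "
        f"'FLASH_ATTN', but got '{backend.get_name()}'")


validate_enc_dec_backend(FakeBackend("FLASH_ATTN"))   # passes
# validate_enc_dec_backend(FakeBackend("ROCM_FLASH"))  # would raise AssertionError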
