AutoModel supports FA2/paged attention #2133

Closed
wants to merge 4 commits into from
Changes from 3 commits
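
In short, the diff below replaces the plain CausalLM fallback in get_model() with a choice between two Transformers-backed fallbacks. The snippet that follows is a minimal standalone sketch of that selection logic, assuming the names introduced in this diff (USE_CUSTOM_MODELING from models/globals.py, TransformersCausalLM, TransformersFlashCausalLM); the helper function and its string return value are illustrative only and are not part of the PR.

import transformers
from transformers.models.auto import modeling_auto


def pick_fallback_class_name(model_type: str, use_custom_modeling: bool) -> str:
    # Default fallback: classic padded tensors (batch_size x seq_len).
    fallback = "TransformersCausalLM"
    if (
        not use_custom_modeling
        and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
    ):
        auto_class = getattr(
            transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
        )
        # Both flags must be set on the Transformers class to use the
        # flash/paged path with ragged tensors (a single flattened
        # dimension for batch and sequence length).
        if auto_class._supports_flash_attn_2 and auto_class._supports_cache_class:
            fallback = "TransformersFlashCausalLM"
    return fallback

# e.g. pick_fallback_class_name("llama", use_custom_modeling=False)
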
server/tests/models/test_bloom.py (2 changes: 1 addition & 1 deletion)
@@ -5,7 +5,7 @@
from transformers import AutoTokenizer

from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.transformers_causal_lm import CausalLMBatch
from text_generation_server.utils import weight_hub_files, download_weights
from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded

server/tests/models/test_causal_lm.py (5 changes: 4 additions & 1 deletion)
@@ -5,7 +5,10 @@
from transformers import AutoTokenizer

from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
from text_generation_server.models.transformers_causal_lm import (
TransformersCausalLM,
CausalLMBatch,
)


@pytest.fixture(scope="session")
server/tests/models/test_santacoder.py (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
import pytest

from text_generation_server.pb import generate_pb2
from text_generation_server.models.causal_lm import CausalLMBatch
from text_generation_server.models.transformers_causal_lm import CausalLMBatch
from text_generation_server.models.santacoder import SantaCoder


server/text_generation_server/models/__init__.py (99 changes: 65 additions & 34 deletions)
@@ -8,10 +8,13 @@
from huggingface_hub import hf_hub_download, HfApi
from typing import Optional, List
from pathlib import Path

import transformers
from text_generation_server.utils.speculate import get_speculate, set_speculate
from text_generation_server.models.model import Model
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.transformers_causal_lm import TransformersCausalLM
from text_generation_server.models.transformers_flash_causal_lm import (
TransformersFlashCausalLM,
)
from text_generation_server.models.flash_causal_lm import FlashCausalLM
from text_generation_server.models.bloom import BLOOMSharded
from text_generation_server.models.mpt import MPTSharded
@@ -24,6 +27,8 @@
from text_generation_server.models.gpt_neox import GPTNeoxSharded
from text_generation_server.models.phi import Phi

from text_generation_server.models.globals import USE_CUSTOM_MODELING

from text_generation_server.utils.import_utils import SYSTEM

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
@@ -288,6 +293,31 @@ def get_model(
)
model_type = config_dict.get("model_type", None)

transformers_causal_lm_class = TransformersCausalLM
if (
not USE_CUSTOM_MODELING
and model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
):
logger.info(
"TGI's flash enabled models could either not be loaded or are disabled, using Transformers fallback."
)
transformers_model_class = getattr(
transformers, modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]
)

if (
transformers_model_class._supports_flash_attn_2
and transformers_model_class._supports_cache_class
):
logger.info(
f"Transformers' {model_type} implementation supports custom cache and flash/paged attention. Using TransformersFlashCausalLM with ragged tensors (single dimension for batch and sequence length)."
)
transformers_causal_lm_class = TransformersFlashCausalLM
else:
logger.info(
f"Transformers' {model_type} implementation does not support both a custom cache class and flash/paged attention. Using TransformersCausalLM with classic padded tensors (two dimensions for batch size and sequence length)."
)

speculator = None
if "medusa_num_heads" in config_dict:
medusa_model_id = model_id
@@ -449,7 +479,7 @@ def get_model(
or model_type == GPT2
and model_id.startswith("bigcode/")
):
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashSantacoderSharded(
model_id,
revision,
Expand Down Expand Up @@ -491,7 +521,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == GPT2:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
try:
return FlashGPT2(
model_id,
@@ -504,7 +534,8 @@ def get_model(
except RuntimeError as e:
# Lots of legacy models with various weight names.
logger.warning(f"Couldn't load flash gpt2 variant: {e}")
return CausalLM(

return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,
@@ -515,7 +546,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2"))
else:
return CausalLM(
return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,
@@ -524,7 +555,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
elif model_type == GPT_NEOX:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashNeoXSharded(
model_id,
revision,
@@ -543,7 +574,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
else:
return CausalLM(
return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,
@@ -553,7 +584,7 @@ def get_model(
)

elif model_type == PHI:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashPhi(
model_id,
revision,
@@ -563,7 +594,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
else:
return CausalLM(
return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,
@@ -573,7 +604,7 @@ def get_model(
)

elif model_type == "phi-msft":
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
raise NotImplementedError(
"Legacy phi-msft is not supported with Flash Attention"
)
@@ -588,7 +619,7 @@ def get_model(
)

elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashLlama(
model_id,
revision,
@@ -601,7 +632,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
else:
return CausalLM(
return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,
@@ -610,7 +641,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == GEMMA:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashGemma(
model_id,
revision,
@@ -622,7 +653,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma"))
else:
return CausalLM(
return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,
@@ -632,7 +663,7 @@ def get_model(
)

if model_type == COHERE:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashCohere(
model_id,
revision,
@@ -644,7 +675,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
else:
return CausalLM(
return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,
@@ -654,7 +685,7 @@ def get_model(
)

if model_type == DBRX:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashDbrx(
model_id,
revision,
@@ -666,7 +697,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX"))
else:
return CausalLM(
return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,
@@ -677,7 +708,7 @@ def get_model(

if model_type in ["RefinedWeb", "RefinedWebModel", FALCON]:
if sharded:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
if config_dict.get("alibi", False):
raise NotImplementedError("sharded is not supported for this model")
return FlashRWSharded(
@@ -710,7 +741,7 @@ def get_model(
)

if model_type == MISTRAL:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashMistral(
model_id,
revision,
@@ -722,7 +753,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral"))
else:
return CausalLM(
return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,
@@ -732,7 +763,7 @@ def get_model(
)

if model_type == MIXTRAL:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashMixtral(
model_id,
revision,
@@ -744,7 +775,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral"))
else:
return CausalLM(
return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,
@@ -754,7 +785,7 @@ def get_model(
)

if model_type == STARCODER2:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashStarcoder2(
model_id,
revision,
@@ -767,7 +798,7 @@ def get_model(
FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2")
)
else:
return CausalLM(
return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,
@@ -777,7 +808,7 @@ def get_model(
)

if model_type == QWEN2:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return FlashQwen2(
model_id,
revision,
@@ -788,7 +819,7 @@ def get_model(
elif sharded:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2"))
else:
return CausalLM(
return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,
@@ -817,7 +848,7 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == IDEFICS:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return IDEFICSSharded(
model_id,
revision,
@@ -829,7 +860,7 @@ def get_model(
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == IDEFICS2:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return Idefics2(
model_id,
revision,
@@ -841,7 +872,7 @@ def get_model(
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
if model_type == "paligemma":
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return PaliGemma(
model_id,
revision,
@@ -854,7 +885,7 @@ def get_model(
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

if model_type == LLAVA_NEXT:
if FLASH_ATTENTION:
if FLASH_ATTENTION and USE_CUSTOM_MODELING:
return LlavaNext(
model_id,
revision,
@@ -881,7 +912,7 @@ def get_model(
elif quantize == "exl2":
raise NotImplementedError("exl2 quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,
@@ -902,7 +933,7 @@ def get_model(
auto_map = config_dict.get("auto_map", None)
if trust_remote_code and auto_map is not None:
if "AutoModelForCausalLM" in auto_map.keys():
return CausalLM(
return transformers_causal_lm_class(
model_id,
revision,
quantize=quantize,