Skip to content

Commit

Permalink
decouple flashinfer files from flash attention (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
alfredgui2 authored Jun 22, 2024
1 parent 9b3c098 commit 2311872
Show file tree
Hide file tree
Showing 23 changed files with 846 additions and 83 deletions.
18 changes: 12 additions & 6 deletions server/examples/test_local_api.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
from text_generation_server.pb import generate_pb2
import torch
from text_generation_server.models.flashinfer_llama import FlashinferLlama
from text_generation_server.models.flashinfer_gemma import FlashinferGemma
from text_generation_server.models_flashinfer.flashinfer_llama import FlashinferLlama
from text_generation_server.models_flashinfer.flashinfer_gemma import FlashinferGemma
import sys

try:
from text_generation_server.models.flashinfer_mistral import FlashinferMistral
from text_generation_server.models.flashinfer_phi import FlashinferPhi
from text_generation_server.models.flashinfer_qwen2 import FlashinferQwen2
from text_generation_server.models_flashinfer.flashinfer_mistral import (
FlashinferMistral,
)
from text_generation_server.models_flashinfer.flashinfer_phi import FlashinferPhi
from text_generation_server.models_flashinfer.flashinfer_qwen2 import (
FlashinferQwen2,
)
except:
print("can't load flashinfer mistral and phi and qwen2 without flash attn")

from text_generation_server.models.flashinfer_causal_lm import FlashinferBatch
from text_generation_server.models_flashinfer.flashinfer_causal_lm import (
FlashinferBatch,
)
import random, json
from test_cases import DEMO, LoraSpec

Expand Down
39 changes: 28 additions & 11 deletions server/text_generation_server/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def serve(
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
use_flashinfer: bool = True,
):
if sharded:
assert (
Expand Down Expand Up @@ -90,17 +91,33 @@ def serve(
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
server.serve(
model_id,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
uds_path,
lora_ids,
)

if use_flashinfer:
from text_generation_server import server_flashinfer

server_flashinfer.serve(
model_id,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
uds_path,
lora_ids,
)
else:
server.serve(
model_id,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
uds_path,
lora_ids,
)


@app.command()
Expand Down
56 changes: 0 additions & 56 deletions server/text_generation_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,25 +101,6 @@
__all__.append(FlashGemma)
__all__.append(FlashCohere)

FLASHINFER_AVAILABLE = True
try:
from text_generation_server.models.flashinfer_llama import FlashinferLlama
from text_generation_server.models.flashinfer_gemma import FlashinferGemma
from text_generation_server.models.flashinfer_mistral import FlashinferMistral
from text_generation_server.models.flashinfer_phi import FlashinferPhi
from text_generation_server.models.flashinfer_qwen2 import FlashinferQwen2

except ImportError as e:
logger.warning(f"Could not import FlashInfer: {e}")
FLASHINFER_AVAILABLE = False

if FLASHINFER_AVAILABLE:
__all__.append(FlashinferLlama)
__all__.append(FlashinferGemma)
__all__.append(FlashinferMistral)
__all__.append(FlashinferPhi)
__all__.append(FlashinferQwen2)

MAMBA_AVAILABLE = True
try:
from text_generation_server.models.mamba import Mamba
Expand Down Expand Up @@ -564,13 +545,6 @@ def get_model(
)

elif model_type == PHI:
if FLASHINFER_AVAILABLE:
return FlashinferPhi(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)
if FLASH_ATTENTION:
return FlashPhi(
model_id,
Expand Down Expand Up @@ -606,14 +580,6 @@ def get_model(
)

elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
if FLASHINFER_AVAILABLE:
return FlashinferLlama(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)

if FLASH_ATTENTION:
return FlashLlama(
model_id,
Expand All @@ -635,13 +601,6 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == GEMMA:
if FLASHINFER_AVAILABLE:
return FlashinferGemma(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)
if FLASH_ATTENTION:
return FlashGemma(
model_id,
Expand Down Expand Up @@ -742,14 +701,6 @@ def get_model(
)

if model_type == MISTRAL:
if FLASHINFER_AVAILABLE:
return FlashinferMistral(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)

sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION:
return FlashMistral(
Expand Down Expand Up @@ -820,13 +771,6 @@ def get_model(
)

if model_type == QWEN2:
if FLASHINFER_AVAILABLE:
return FlashinferQwen2(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)
sliding_window = config_dict.get("sliding_window", -1)
if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
return FlashQwen2(
Expand Down
164 changes: 164 additions & 0 deletions server/text_generation_server/models_flashinfer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import torch
import enum

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from typing import Optional

from text_generation_server.models.model import Model
from text_generation_server.models_flashinfer.flashinfer_llama import FlashinferLlama
from text_generation_server.models_flashinfer.flashinfer_gemma import FlashinferGemma
from text_generation_server.models_flashinfer.flashinfer_mistral import (
FlashinferMistral,
)
from text_generation_server.models_flashinfer.flashinfer_phi import FlashinferPhi
from text_generation_server.models_flashinfer.flashinfer_qwen2 import FlashinferQwen2

# Import-time global PyTorch configuration for the FlashInfer model package.
# These statements run once, as a side effect of importing this module.

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

# Names exported by `from text_generation_server.models_flashinfer import *`:
# the `Model` base class, the FlashInfer model classes imported above, and
# the `get_model` factory defined below.
__all__ = [
"Model",
"FlashinferLlama",
"FlashinferGemma",
"FlashinferMistral",
"FlashinferPhi",
"FlashinferQwen2",
"get_model",
]


class ModelType(enum.Enum):
    """Model families known to the FlashInfer model package.

    Each member's value is a dict with the keys:

    * ``"type"``: the identifier string for the family, intended to match
      the ``model_type`` field of a model's configuration.
    * ``"name"``: a human-readable name.
    * ``"url"``: a link to a representative model page.
    * ``"multimodal"`` (optional): ``True`` for multimodal families; the key
      is absent otherwise.
    """

    LLAVA_NEXT = {
        "type": "llava_next",
        "name": "Llava Next (1.6)",
        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
        "multimodal": True,
    }
    LLAMA = {
        "type": "llama",
        "name": "Llama",
        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
    }
    PHI3 = {
        "type": "phi3",
        "name": "Phi 3",
        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
    }
    GEMMA = {
        "type": "gemma",
        "name": "Gemma",
        "url": "https://huggingface.co/google/gemma-7b",
    }
    MISTRAL = {
        "type": "mistral",
        "name": "Mistral",
        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
    }
    PHI = {
        "type": "phi",
        "name": "Phi",
        "url": "https://huggingface.co/microsoft/phi-1_5",
    }
    BAICHUAN = {
        "type": "baichuan",
        "name": "Baichuan",
        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
    }
    QWEN2 = {
        "type": "qwen2",
        "name": "Qwen 2",
        # Previously pointed at a StarCoder2 repository (copy-paste error);
        # link to an actual Qwen 2 model page instead.
        "url": "https://huggingface.co/Qwen/Qwen2-7B-Instruct",
    }


def get_model(
    model_id: str,
    revision: Optional[str],
    sharded: bool,
    quantize: Optional[str],
    dtype: Optional[str],
    trust_remote_code: bool,
    lora_ids: Optional[str],
) -> Model:
    """Build the FlashInfer model implementation matching ``model_id``.

    The model's configuration is fetched with
    ``PretrainedConfig.get_config_dict`` and its ``model_type`` field selects
    which ``Flashinfer*`` class is instantiated.

    Args:
        model_id: Model identifier, used for the config lookup and passed to
            the selected model class.
        revision: Optional revision used for the config lookup.
        sharded: Whether the model is sharded. Only used here to reject the
            unsupported combination with ``exl2`` quantization.
        quantize: Quantization method, or ``None``. When ``None`` and the
            config carries a ``quantization_config`` whose ``quant_method``
            is ``gptq``, ``awq`` or ``exl2``, that method is selected.
        dtype: ``None``, ``"float16"`` or ``"bfloat16"``.
        trust_remote_code: Forwarded to the config lookup.
        lora_ids: ``;``-separated LoRA identifiers; split into a list before
            being passed to the model class. ``None`` or an empty string is
            passed on as ``None``.

    Returns:
        An instance of the ``Flashinfer*`` class registered for the model type.

    Raises:
        RuntimeError: If ``dtype`` is not a recognised name, the config has
            no ``model_type``, or ``exl2`` quantization is combined with
            sharding.
        ValueError: If the model type has no FlashInfer implementation.
    """
    if dtype is None:
        # These quantizers only work with float16 params. Otherwise keep
        # ``None`` and let every model resolve its own default dtype.
        # NOTE: this runs before quantization auto-detection below, so a
        # method picked up from the config does not force float16 here.
        if quantize in ["awq", "exl2", "gptq"]:
            dtype = torch.float16
    elif dtype == "float16":
        dtype = torch.float16
    elif dtype == "bfloat16":
        dtype = torch.bfloat16
    else:
        raise RuntimeError(f"Unknown dtype {dtype}")

    config_dict, _ = PretrainedConfig.get_config_dict(
        model_id, revision=revision, trust_remote_code=trust_remote_code
    )
    model_type = config_dict.get("model_type", None)
    if model_type is None:
        raise RuntimeError(
            f"Could not determine model type for {model_id} revision {revision}"
        )

    quantization_config = config_dict.get("quantization_config", None)
    if quantization_config is not None and quantize is None:
        method = quantization_config.get("quant_method", None)
        if method in {"gptq", "awq", "exl2"}:
            logger.info(f"Auto selecting quantization method {method}")
            quantize = method
        else:
            logger.info(f"Unknown quantization method {method}")

    if quantize == "exl2" and sharded:
        raise RuntimeError(
            "Sharding is currently not supported with `exl2` quantization"
        )

    # Resolve type identifiers through the ModelType enum. The bare names
    # (PHI, LLAMA, ...) are not defined in this module, so comparing against
    # them raised NameError.
    model_classes = {
        ModelType.PHI.value["type"]: FlashinferPhi,
        ModelType.LLAMA.value["type"]: FlashinferLlama,
        ModelType.BAICHUAN.value["type"]: FlashinferLlama,
        ModelType.PHI3.value["type"]: FlashinferLlama,
        ModelType.GEMMA.value["type"]: FlashinferGemma,
        ModelType.MISTRAL.value["type"]: FlashinferMistral,
        ModelType.QWEN2.value["type"]: FlashinferQwen2,
    }
    model_class = model_classes.get(model_type)
    if model_class is None:
        raise ValueError(f"Unsupported model type {model_type}")

    return model_class(
        model_id,
        lora_ids.split(";") if lora_ids else None,
        quantize=quantize,
        dtype=dtype,
    )
Empty file.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
import torch.distributed
from typing import Optional, List
from text_generation_server.models.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models.custom_modeling.flashinfer_gemma_modeling import (
from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models_flashinfer.custom_modeling.flashinfer_gemma_modeling import (
GemmaTokenizerFast,
GemmaConfig,
FlashGemmaForCausalLM,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from transformers.models.llama import LlamaTokenizer
from typing import Optional, List

from text_generation_server.models.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models.custom_modeling.flashinfer_llama_modeling import (
from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models_flashinfer.custom_modeling.flashinfer_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.utils import (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
import torch.distributed
from typing import Optional, List
from text_generation_server.models.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models.custom_modeling.flashinfer_mistral_modeling import (
from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models_flashinfer.custom_modeling.flashinfer_mistral_modeling import (
MistralConfig,
FlashMistralForCausalLM,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch
import torch.distributed
from typing import Optional, List
from text_generation_server.models.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models.custom_modeling.flashinfer_phi_modeling import (
from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models_flashinfer.custom_modeling.flashinfer_phi_modeling import (
FlashPhiForCausalLM,
PhiConfig,
)
Expand Down
Loading

0 comments on commit 2311872

Please sign in to comment.