Isolate flashinfer inference code path from mainstream TGI code paths #41

Merged
merged 1 commit on Jun 22, 2024
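In short, the commit moves every FlashInfer model class out of text_generation_server.models and into a new text_generation_server.models_flashinfer package, so the FlashInfer path no longer shares the mainstream TGI model registry. A minimal sketch of the resulting import change, using module names taken from the diff below:

```python
# Before this PR (old layout, removed below):
#   from text_generation_server.models.flashinfer_llama import FlashinferLlama

# After this PR (isolated layout, added below):
from text_generation_server.models_flashinfer.flashinfer_llama import FlashinferLlama
```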
18 changes: 12 additions & 6 deletions server/examples/test_local_api.py
@@ -1,17 +1,23 @@
from text_generation_server.pb import generate_pb2
import torch
from text_generation_server.models.flashinfer_llama import FlashinferLlama
from text_generation_server.models.flashinfer_gemma import FlashinferGemma
from text_generation_server.models_flashinfer.flashinfer_llama import FlashinferLlama
from text_generation_server.models_flashinfer.flashinfer_gemma import FlashinferGemma
import sys

try:
from text_generation_server.models.flashinfer_mistral import FlashinferMistral
from text_generation_server.models.flashinfer_phi import FlashinferPhi
from text_generation_server.models.flashinfer_qwen2 import FlashinferQwen2
from text_generation_server.models_flashinfer.flashinfer_mistral import (
FlashinferMistral,
)
from text_generation_server.models_flashinfer.flashinfer_phi import FlashinferPhi
from text_generation_server.models_flashinfer.flashinfer_qwen2 import (
FlashinferQwen2,
)
except:
print("can't load flashinfer mistral and phi and qwen2 without flash attn")

from text_generation_server.models.flashinfer_causal_lm import FlashinferBatch
from text_generation_server.models_flashinfer.flashinfer_causal_lm import (
FlashinferBatch,
)
import random, json
from test_cases import DEMO, LoraSpec

39 changes: 28 additions & 11 deletions server/text_generation_server/cli.py
@@ -43,6 +43,7 @@ def serve(
logger_level: str = "INFO",
json_output: bool = False,
otlp_endpoint: Optional[str] = None,
use_flashinfer: bool = True,
):
if sharded:
assert (
@@ -90,17 +91,33 @@ def serve(
raise RuntimeError(
"Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model."
)
server.serve(
model_id,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
uds_path,
lora_ids,
)

if use_flashinfer:
from text_generation_server import server_flashinfer

server_flashinfer.serve(
model_id,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
uds_path,
lora_ids,
)
else:
server.serve(
model_id,
revision,
sharded,
quantize,
speculate,
dtype,
trust_remote_code,
uds_path,
lora_ids,
)


@app.command()
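Because use_flashinfer defaults to True, serve now routes to server_flashinfer unless the caller opts out. A hedged sketch of opting back into the mainstream path by calling the command function directly; only model_id, lora_ids, and use_flashinfer appear in this diff, and the assumption that the remaining parameters keep their defaults is mine:

```python
# Illustrative only: call the typer command function directly and opt out of the
# FlashInfer path. The model id is an example, not something from this PR.
from text_generation_server import cli

cli.serve(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    use_flashinfer=False,  # falls through to server.serve(), the mainstream TGI path
)
```

From the shell, typer should expose the same switch as a --use-flashinfer / --no-use-flashinfer flag pair on the serve command.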
56 changes: 0 additions & 56 deletions server/text_generation_server/models/__init__.py
@@ -101,25 +101,6 @@
__all__.append(FlashGemma)
__all__.append(FlashCohere)

FLASHINFER_AVAILABLE = True
try:
from text_generation_server.models.flashinfer_llama import FlashinferLlama
from text_generation_server.models.flashinfer_gemma import FlashinferGemma
from text_generation_server.models.flashinfer_mistral import FlashinferMistral
from text_generation_server.models.flashinfer_phi import FlashinferPhi
from text_generation_server.models.flashinfer_qwen2 import FlashinferQwen2

except ImportError as e:
logger.warning(f"Could not import FlashInfer: {e}")
FLASHINFER_AVAILABLE = False

if FLASHINFER_AVAILABLE:
__all__.append(FlashinferLlama)
__all__.append(FlashinferGemma)
__all__.append(FlashinferMistral)
__all__.append(FlashinferPhi)
__all__.append(FlashinferQwen2)

MAMBA_AVAILABLE = True
try:
from text_generation_server.models.mamba import Mamba
@@ -564,13 +545,6 @@ def get_model(
)

elif model_type == PHI:
if FLASHINFER_AVAILABLE:
return FlashinferPhi(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)
if FLASH_ATTENTION:
return FlashPhi(
model_id,
@@ -606,14 +580,6 @@ def get_model(
)

elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
if FLASHINFER_AVAILABLE:
return FlashinferLlama(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)

if FLASH_ATTENTION:
return FlashLlama(
model_id,
@@ -635,13 +601,6 @@ def get_model(
trust_remote_code=trust_remote_code,
)
if model_type == GEMMA:
if FLASHINFER_AVAILABLE:
return FlashinferGemma(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)
if FLASH_ATTENTION:
return FlashGemma(
model_id,
@@ -742,14 +701,6 @@ def get_model(
)

if model_type == MISTRAL:
if FLASHINFER_AVAILABLE:
return FlashinferMistral(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)

sliding_window = config_dict.get("sliding_window", -1)
if FLASH_ATTENTION:
return FlashMistral(
@@ -820,13 +771,6 @@ def get_model(
)

if model_type == QWEN2:
if FLASHINFER_AVAILABLE:
return FlashinferQwen2(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)
sliding_window = config_dict.get("sliding_window", -1)
if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING:
return FlashQwen2(
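Any downstream code that keyed off the FLASHINFER_AVAILABLE flag removed above can probe the new package instead. A hedged sketch of an equivalent guard, modeled on the deleted block; none of this is part of the commit itself:

```python
# Illustrative guard mirroring the removed FLASHINFER_AVAILABLE check.
from loguru import logger

try:
    from text_generation_server.models_flashinfer import get_model as get_flashinfer_model

    FLASHINFER_AVAILABLE = True
except ImportError as e:
    logger.warning(f"Could not import FlashInfer: {e}")
    FLASHINFER_AVAILABLE = False
```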
164 changes: 164 additions & 0 deletions server/text_generation_server/models_flashinfer/__init__.py
@@ -0,0 +1,164 @@
import torch
import enum

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from typing import Optional

from text_generation_server.models.model import Model
from text_generation_server.models_flashinfer.flashinfer_llama import FlashinferLlama
from text_generation_server.models_flashinfer.flashinfer_gemma import FlashinferGemma
from text_generation_server.models_flashinfer.flashinfer_mistral import (
FlashinferMistral,
)
from text_generation_server.models_flashinfer.flashinfer_phi import FlashinferPhi
from text_generation_server.models_flashinfer.flashinfer_qwen2 import FlashinferQwen2

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True

# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

# Disable gradients
torch.set_grad_enabled(False)

__all__ = [
"Model",
"FlashinferLlama",
"FlashinferGemma",
"FlashinferMistral",
"FlashinferPhi",
"FlashinferQwen2",
"get_model",
]


class ModelType(enum.Enum):
LLAVA_NEXT = {
"type": "llava_next",
"name": "Llava Next (1.6)",
"url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
"multimodal": True,
}
LLAMA = {
"type": "llama",
"name": "Llama",
"url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
}
PHI3 = {
"type": "phi3",
"name": "Phi 3",
"url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
}
GEMMA = {
"type": "gemma",
"name": "Gemma",
"url": "https://huggingface.co/google/gemma-7b",
}
MISTRAL = {
"type": "mistral",
"name": "Mistral",
"url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
}
PHI = {
"type": "phi",
"name": "Phi",
"url": "https://huggingface.co/microsoft/phi-1_5",
}
BAICHUAN = {
"type": "baichuan",
"name": "Baichuan",
"url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
}
QWEN2 = {
"type": "qwen2",
"name": "Qwen 2",
"url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
}


def get_model(
model_id: str,
revision: Optional[str],
sharded: bool,
quantize: Optional[str],
dtype: Optional[str],
trust_remote_code: bool,
lora_ids: Optional[str],
) -> Model:
if dtype is None:
if quantize in ["awq", "exl2", "gptq"]:
# These quantizers only work with float16 params.
dtype = torch.float16
else:
# Keep it as default for now and let
# every model resolve their own default dtype.
dtype = None
elif dtype == "float16":
dtype = torch.float16
elif dtype == "bfloat16":
dtype = torch.bfloat16
else:
raise RuntimeError(f"Unknown dtype {dtype}")

config_dict, _ = PretrainedConfig.get_config_dict(
model_id, revision=revision, trust_remote_code=trust_remote_code
)
model_type = config_dict.get("model_type", None)
if model_type is None:
raise RuntimeError(
f"Could not determine model type for {model_id} revision {revision}"
)
quantization_config = config_dict.get("quantization_config", None)
if quantization_config is not None and quantize is None:
method = quantization_config.get("quant_method", None)
if method in {"gptq", "awq", "exl2"}:
logger.info(f"Auto selecting quantization method {method}")
quantize = method
else:
logger.info(f"Unknown quantization method {method}")

if quantize == "exl2" and sharded:
raise RuntimeError(
"Sharding is currently not supported with `exl2` quantization"
)

if model_type == PHI:
return FlashinferPhi(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)
elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3:
return FlashinferLlama(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)
elif model_type == GEMMA:
return FlashinferGemma(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)
elif model_type == MISTRAL:
return FlashinferMistral(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)
elif model_type == QWEN2:
return FlashinferQwen2(
model_id,
lora_ids.split(";") if lora_ids else None,
quantize=quantize,
dtype=dtype,
)

raise ValueError(f"Unsupported model type {model_type}")
server/text_generation_server/models_flashinfer/flashinfer_gemma.py
@@ -1,8 +1,8 @@
import torch
import torch.distributed
from typing import Optional, List
from text_generation_server.models.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models.custom_modeling.flashinfer_gemma_modeling import (
from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models_flashinfer.custom_modeling.flashinfer_gemma_modeling import (
GemmaTokenizerFast,
GemmaConfig,
FlashGemmaForCausalLM,
server/text_generation_server/models_flashinfer/flashinfer_llama.py
@@ -5,8 +5,8 @@
from transformers.models.llama import LlamaTokenizer
from typing import Optional, List

from text_generation_server.models.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models.custom_modeling.flashinfer_llama_modeling import (
from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models_flashinfer.custom_modeling.flashinfer_llama_modeling import (
FlashLlamaForCausalLM,
)
from text_generation_server.utils import (
server/text_generation_server/models_flashinfer/flashinfer_mistral.py
@@ -1,8 +1,8 @@
import torch
import torch.distributed
from typing import Optional, List
from text_generation_server.models.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models.custom_modeling.flashinfer_mistral_modeling import (
from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models_flashinfer.custom_modeling.flashinfer_mistral_modeling import (
MistralConfig,
FlashMistralForCausalLM,
)
server/text_generation_server/models_flashinfer/flashinfer_phi.py
@@ -1,8 +1,8 @@
import torch
import torch.distributed
from typing import Optional, List
from text_generation_server.models.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models.custom_modeling.flashinfer_phi_modeling import (
from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM
from text_generation_server.models_flashinfer.custom_modeling.flashinfer_phi_modeling import (
FlashPhiForCausalLM,
PhiConfig,
)