Commit 922f9d2 (1 parent: 0619926)
Showing 7 changed files with 341 additions and 187 deletions.
@@ -1,141 +1,133 @@
 import asyncio
 import logging
 import os
-from typing import Dict, Any, Optional
+from typing import Dict, Any, List, Optional

 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from app.pipelines.base import Pipeline
 from app.pipelines.utils import get_model_dir, get_torch_device
 from huggingface_hub import file_download, hf_hub_download
 from threading import Thread
 from typing import AsyncGenerator, Union, Dict, Any, Optional, List

 logger = logging.getLogger(__name__)


-# class LLMGeneratePipeline(Pipeline):
-#     def __init__(self, model_id: str):
-#         self.model_id = model_id
-#         kwargs = {
-#             "cache_dir": get_model_dir()
-#         }
-#         self.device = get_torch_device()
-#         folder_name = file_download.repo_folder_name(
-#             repo_id=model_id, repo_type="model"
-#         )
-#         folder_path = os.path.join(get_model_dir(), folder_name)
-
-#         # Check for fp16 variant
-#         has_fp16_variant = any(
-#             ".fp16.safetensors" in fname
-#             for _, _, files in os.walk(folder_path)
-#             for fname in files
-#         )
-#         if self.device != "cpu" and has_fp16_variant:
-#             logger.info("LLMGeneratePipeline loading fp16 variant for %s", model_id)
-#             kwargs["torch_dtype"] = torch.float16
-#             kwargs["variant"] = "fp16"
-
-#         # Load tokenizer
-#         self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
-
-#         # Load model
-#         self.model = AutoModelForCausalLM.from_pretrained(
-#             model_id, **kwargs).to(self.device)
-
-#         # Set up generation config
-#         self.generation_config = self.model.generation_config
-
-#         # Optional: Add optimizations
-#         sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
-#         if sfast_enabled:
-#             logger.info(
-#                 "LLMGeneratePipeline will be dynamically compiled with stable-fast for %s",
-#                 model_id,
-#             )
-#             from app.pipelines.optim.sfast import compile_model
-#             self.model = compile_model(self.model)
-
-#     def __call__(self, prompt: str, system_msg: Optional[str] = None,
-#                  temperature: Optional[float] = None,
-#                  max_tokens: Optional[int] = None, **kwargs) -> Dict[str, Any]:
-#         if system_msg:
-#             input_text = f"{system_msg}\n\n{prompt}"
-#         else:
-#             input_text = prompt
-
-#         input_ids = self.tokenizer.encode(
-#             input_text, return_tensors="pt").to(self.device)
-
-#         # Update generation config
-#         gen_kwargs = {}
-#         if temperature is not None:
-#             gen_kwargs['temperature'] = temperature
-#         if max_tokens is not None:
-#             gen_kwargs['max_new_tokens'] = max_tokens
-
-#         # Merge generation config with provided kwargs
-#         gen_kwargs = {**self.generation_config.to_dict(), **gen_kwargs, **kwargs}
-
-#         # Generate response
-#         with torch.no_grad():
-#             output = self.model.generate(
-#                 input_ids,
-#                 **gen_kwargs
-#             )
-
-#         # Decode the response
-#         response = self.tokenizer.decode(output[0], skip_special_tokens=True)
-
-#         # Calculate tokens used
-#         tokens_used = len(output[0])
-
-#         return {
-#             "response": response.strip(),
-#             "tokens_used": tokens_used
-#         }
-
-#     def __str__(self) -> str:
-#         return f"LLMPipeline model_id={self.model_id}"
-
-
 class LLMGeneratePipeline(Pipeline):
     def __init__(self, model_id: str):
         self.model_id = model_id
-        self.device = get_torch_device()

         kwargs = {
-            "cache_dir": get_model_dir(),
-            "device_map": "auto",
-            "torch_dtype": torch.bfloat16 if self.device != "cpu" else torch.float32,
+            "cache_dir": get_model_dir()
         }

-        logger.info(f"Loading model {model_id}")
-        self.pipeline = pipeline(
-            "text-generation",
-            model=model_id,
-            tokenizer=model_id,
-            **kwargs
+        self.device = get_torch_device()
+        folder_name = file_download.repo_folder_name(
+            repo_id=model_id, repo_type="model"
         )
+        folder_path = os.path.join(get_model_dir(), folder_name)

-    def __call__(self, prompt: str, system_msg: str = None, **kwargs):
-        messages = []
-        if system_msg:
-            messages.append({"role": "system", "content": system_msg})
-        messages.append({"role": "user", "content": prompt})

-        outputs = self.pipeline(
-            messages,
-            max_new_tokens=kwargs.get("max_tokens", 256),
-            temperature=kwargs.get("temperature", 0.7),
+        # Check for fp16 variant
+        has_fp16_variant = any(
+            ".fp16.safetensors" in fname
+            for _, _, files in os.walk(folder_path)
+            for fname in files
         )
+        if self.device != "cpu" and has_fp16_variant:
+            logger.info("LLMGeneratePipeline loading fp16 variant for %s", model_id)
+            kwargs["torch_dtype"] = torch.float16
+            kwargs["variant"] = "fp16"
+        elif self.device != "cpu":
+            kwargs["torch_dtype"] = torch.bfloat16

-        response = outputs[0]["generated_text"]
-        # Assuming the response is the last message in the conversation
-        response = response.split("assistant:")[-1].strip()
+        # Add device mapping
+        kwargs["device_map"] = "auto"

-        return {
-            "response": response,
-            "tokens_used": len(self.pipeline.tokenizer.encode(response))
-        }
+        logger.info(f"Loading model {model_id}")
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs)
+        self.model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs)

+        # Set up generation config
+        self.generation_config = self.model.generation_config
+
+        self.terminators = [
+            self.tokenizer.eos_token_id,
+            self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
+        ]
+
+        # Optional: Add optimizations
+        sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
+        if sfast_enabled:
+            logger.info(
+                "LLMGeneratePipeline will be dynamically compiled with stable-fast for %s",
+                model_id,
+            )
+            from app.pipelines.optim.sfast import compile_model
+            self.model = compile_model(self.model)
+
+    async def __call__(self, prompt: str, history: Optional[List[tuple]] = None, system_msg: Optional[str] = None, **kwargs) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
+        conversation = []
+        if system_msg:
+            conversation.append({"role": "system", "content": system_msg})
+        if history:
+            for user, assistant in history:
+                conversation.extend([{"role": "user", "content": user}, {
+                    "role": "assistant", "content": assistant}])
+        conversation.append({"role": "user", "content": prompt})
+
+        input_ids = self.tokenizer.apply_chat_template(
+            conversation, return_tensors="pt").to(self.model.device)
+        attention_mask = torch.ones_like(input_ids)
+
+        max_new_tokens = kwargs.get("max_tokens", 256)
+        temperature = kwargs.get("temperature", 0.7)
+
+        streamer = TextIteratorStreamer(
+            self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+        # Start with the generation config
+        generate_kwargs = self.generation_config.to_dict()
+        # Update with our specific parameters
+        generate_kwargs.update({
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "eos_token_id": self.tokenizer.eos_token_id,
+            "pad_token_id": self.tokenizer.eos_token_id,
+        })
+
+        # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
+        if temperature == 0:
+            generate_kwargs['do_sample'] = False
+
+        # Start generation in a separate thread
+        thread = Thread(target=self.model_generate_wrapper, kwargs=generate_kwargs)
+        thread.start()
+
+        total_tokens = 0
+        try:
+            for text in streamer:
+                total_tokens += 1
+                yield text
+                await asyncio.sleep(0)  # Allow other tasks to run
+        except Exception as e:
+            logger.error(f"Error during streaming: {str(e)}")
+            raise
+
+        input_length = input_ids.size(1)
+        yield {"tokens_used": input_length + total_tokens}
+
+    def __str__(self):
+        return f"LLMGeneratePipeline(model_id={self.model_id})"
+
+    def model_generate_wrapper(self, **kwargs):
+        try:
+            logger.debug("Entering model.generate")
+            with torch.cuda.amp.autocast():  # Use automatic mixed precision
+                self.model.generate(**kwargs)
+            logger.debug("Exiting model.generate")
+        except Exception as e:
+            logger.error(f"Error in model.generate: {str(e)}", exc_info=True)
+            raise
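For orientation, here is a minimal sketch of how the new streaming __call__ might be consumed from async code. The import path and model id below are assumptions made for illustration only; they are not part of this commit.

import asyncio

# Hypothetical import path and model id, used only for illustration.
from app.pipelines.llm_generate import LLMGeneratePipeline


async def main():
    pipe = LLMGeneratePipeline("meta-llama/Meta-Llama-3-8B-Instruct")

    tokens_used = None
    # __call__ is an async generator: it yields text chunks as they are produced
    # and ends with a metadata dict carrying the token count.
    async for chunk in pipe("Explain video transcoding in one sentence.",
                            system_msg="You are a concise assistant."):
        if isinstance(chunk, dict):
            tokens_used = chunk.get("tokens_used")
        else:
            print(chunk, end="", flush=True)

    print(f"\ntokens_used: {tokens_used}")


if __name__ == "__main__":
    asyncio.run(main())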
@@ -0,0 +1,18 @@
+import torch
+import subprocess
+
+print(f"PyTorch version: {torch.__version__}")
+print(f"CUDA available: {torch.cuda.is_available()}")
+if torch.cuda.is_available():
+    print(f"CUDA version: {torch.version.cuda}")
+
+# Check system CUDA version
+try:
+    nvcc_output = subprocess.check_output(["nvcc", "--version"]).decode("utf-8")
+    cuda_version = nvcc_output.split("release ")[-1].split(",")[0]
+    print(f"System CUDA version: {cuda_version}")
+except:
+    print("Unable to check system CUDA version")
+
+# Print the current device
+print(f"Current device: {torch.cuda.get_device_name(0)}")