refactor: Better parameters and pipeline for text generation
vidy-br: better-text-generation
RealVidy committed Aug 29, 2024
1 parent 69e9dd6 commit 598003b
Showing 5 changed files with 53 additions and 54 deletions.
54 changes: 31 additions & 23 deletions llm_inference/hf_inference_server.py
@@ -3,56 +3,64 @@

import torch
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline

from llm_inference.generated_text_response import GeneratedTextResponse
from llm_inference.prompt_request import PromptRequest

app = FastAPI()


# Login to Hugging Face using the auth token
hf_token = os.getenv("HUGGING_FACE_TOKEN")
if hf_token is None:
raise ValueError(
"Please set your Hugging Face token in the HUGGING_FACE_TOKEN environment variable."
)

# Load the model and tokenizer at startup
# Load the pipeline at startup
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME, token=hf_token, padding_side="left"
)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME, torch_dtype=torch.float16, token=hf_token

# Check for the best available device: NVIDIA GPU, MPS, or CPU
if torch.cuda.is_available():
device = torch.device("cuda") # NVIDIA GPU
elif torch.backends.mps.is_available():
device = torch.device("mps") # Apple Silicon GPU (M1, M2, M3, etc.)
else:
device = torch.device("cpu") # Fallback to CPU

text_generator = pipeline(
"text-generation",
model=MODEL_NAME,
tokenizer=MODEL_NAME,
torch_dtype=torch.bfloat16 if torch.backends.mps.is_available() else torch.float16,
device=device,
token=hf_token,
return_full_text=False,
)
device = next(model.parameters()).device
model.to(device)


@app.post("/generate", response_model=GeneratedTextResponse)
async def generate_text(request: PromptRequest):
inputs = tokenizer(request.prompts, return_tensors="pt", padding=True).to(device)

with torch.no_grad():
outputs = await asyncio.to_thread(
model.generate,
**inputs,
max_length=request.max_length,
generated_texts = await asyncio.to_thread(
text_generator,
request.prompts,
do_sample=request.do_sample,
max_new_tokens=request.max_new_tokens,
num_return_sequences=request.num_return_sequences,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
num_return_sequences=request.num_return_sequences,
do_sample=request.do_sample,
top_p=request.top_p,
)

generated_texts = [
tokenizer.decode(output, skip_special_tokens=True) for output in outputs
]
# Extract the generated text from the pipeline output
generated_texts = [result[0]["generated_text"] for result in generated_texts]

return GeneratedTextResponse(generated_texts=generated_texts)


if __name__ == "__main__":
import uvicorn

uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run("hf_inference_server:app", host="0.0.0.0", port=8000)
21 changes: 11 additions & 10 deletions llm_inference/model_handler.py
@@ -3,7 +3,7 @@
import boto3
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import pipeline


class S3ModelHandler:
@@ -22,8 +22,7 @@ def __init__(
self.s3_endpoint_url = s3_endpoint_url
self.s3_secret_access_key = s3_secret_access_key
self.s3_access_key_id = s3_access_key_id
self.model = None
self.tokenizer = None
self.text_generator = None

def download_model_from_s3(self):
print(
@@ -89,14 +88,16 @@ def load_model(self):
print("Loading model...")

model_name_or_path = self.local_model_dir
self.tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path, padding_side="left"
)
self.tokenizer.pad_token = self.tokenizer.eos_token

# Initialize the model with model parallelism
device_map = "auto" # Automatically split across available GPUs
self.model = AutoModelForCausalLM.from_pretrained(
model_name_or_path, torch_dtype=torch.float16, device_map=device_map
self.text_generator = pipeline(
"text-generation",
model=model_name_or_path,
tokenizer=model_name_or_path,
torch_dtype=torch.bfloat16
if torch.backends.mps.is_available()
else torch.float16,
device_map=device_map,
return_full_text=False,
)
print("Model loaded successfully!")
2 changes: 1 addition & 1 deletion llm_inference/prompt_request.py
@@ -6,9 +6,9 @@

class PromptRequest(BaseModel):
prompts: list[str]
max_length: Optional[int] = 128
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
num_return_sequences: Optional[int] = 1
do_sample: Optional[bool] = False
max_new_tokens: Optional[int] = 64
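Since max_length was replaced by max_new_tokens, a request that only supplies prompts now falls back to greedy decoding with up to 64 new tokens; a minimal sketch:

from llm_inference.prompt_request import PromptRequest

# Only prompts is required; every generation parameter takes its default.
request = PromptRequest(prompts=["Tell me a joke"])
print(request.max_new_tokens)  # 64
print(request.do_sample)       # False
print(request.temperature)     # None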
27 changes: 10 additions & 17 deletions llm_inference/s3_inference_server.py
@@ -41,31 +41,24 @@
model_handler.download_model_from_s3()
model_handler.load_model()

device = next(model_handler.model.parameters()).device


@app.post("/generate", response_model=GeneratedTextResponse)
async def generate_text(request: PromptRequest):
inputs = model_handler.tokenizer(
request.prompts, return_tensors="pt", padding=True
).to(device)

with torch.no_grad():
outputs = await asyncio.to_thread(
model_handler.model.generate,
**inputs,
max_length=request.max_length,
generated_texts = await asyncio.to_thread(
model_handler.text_generator,
request.prompts,
do_sample=request.do_sample,
max_new_tokens=request.max_new_tokens,
num_return_sequences=request.num_return_sequences,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
num_return_sequences=request.num_return_sequences,
do_sample=request.do_sample,
top_p=request.top_p,
)

generated_texts = [
model_handler.tokenizer.decode(output, skip_special_tokens=True)
for output in outputs
]
# Extract the generated text from the pipeline output
generated_texts = [result[0]["generated_text"] for result in generated_texts]

return GeneratedTextResponse(generated_texts=generated_texts)
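For clarity, when the text-generation pipeline is given a list of prompts with return_full_text=False, it returns one list of result dicts per prompt, which is what the result[0]["generated_text"] extraction above relies on. A rough shape sketch (values illustrative):

# Approximate shape of the pipeline output for two prompts with num_return_sequences=1.
example_output = [
    [{"generated_text": " ...continuation of prompt one..."}],
    [{"generated_text": " ...continuation of prompt two..."}],
]
# The comprehension keeps the first generated sequence for each prompt.
first_per_prompt = [result[0]["generated_text"] for result in example_output]
print(first_per_prompt)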


3 changes: 0 additions & 3 deletions scripts/example_batch_request.py
@@ -16,9 +16,6 @@
# Create the payload for the request
payload = {
"prompts": prompts,
"max_length": 64,
"num_return_sequences": 1,
"do_sample": False,
}


