Update LLM server, fix bugs, and format with black #236

Merged: 10 commits, Nov 19, 2024
5 changes: 5 additions & 0 deletions .github/workflows/test_lemonade.yml
@@ -32,6 +32,11 @@ jobs:
conda install pylint
python -m pip check
pip install -e .[llm]
- name: Lint with Black
uses: psf/black@stable
with:
options: "--check --verbose"
src: "./src"
- name: Lint with PyLint
shell: bash -el {0}
run: |
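For local use, a rough Python equivalent of the new Black check might look like the sketch below (assuming Black is installed in the active environment; psf/black@stable pins its own release, so results can differ slightly from a local install):

import subprocess

# Mirror the CI options: fail if any file under ./src would be reformatted
subprocess.run(["black", "--check", "--verbose", "./src"], check=True)
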
6 changes: 4 additions & 2 deletions src/turnkeyml/common/build.py
@@ -282,14 +282,16 @@ def get_wmic_info(command):
try:
output = subprocess.check_output(command, shell=True).decode()
return output.split("\n")[1].strip()
except Exception as e: # pylint: disable=broad-except
except Exception as e: # pylint: disable=broad-except
return str(e)

if os_type == "Windows":
if shutil.which("wmic") is not None:
info_dict["Processor"] = get_wmic_info("wmic cpu get name")
info_dict["OEM System"] = get_wmic_info("wmic computersystem get model")
mem_info_bytes = get_wmic_info("wmic computersystem get TotalPhysicalMemory")
mem_info_bytes = get_wmic_info(
"wmic computersystem get TotalPhysicalMemory"
)
try:
mem_info_gb = round(int(mem_info_bytes) / (1024**3), 2)
info_dict["Physical Memory"] = f"{mem_info_gb} GB"
4 changes: 0 additions & 4 deletions src/turnkeyml/llm/cli.py
@@ -54,10 +54,6 @@ def main():
except ModuleNotFoundError:
pass





# Define the argument parser
parser = cli.CustomArgumentParser(
description="Turnkey analysis and benchmarking of GenAI models. "
192 changes: 173 additions & 19 deletions src/turnkeyml/llm/tools/chat.py
@@ -1,15 +1,27 @@
import argparse
from threading import Thread
import time
import statistics
from threading import Thread, Event
import asyncio
from fastapi import FastAPI, WebSocket
from fastapi.responses import HTMLResponse
from starlette.websockets import WebSocketDisconnect
from pydantic import BaseModel
from transformers import TextIteratorStreamer
from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
import uvicorn
from turnkeyml.state import State
from turnkeyml.tools import Tool
from turnkeyml.llm.tools.adapter import ModelAdapter, TokenizerAdapter

DEFAULT_GENERATE_PARAMS = {
"do_sample": True,
"top_k": 50,
"top_p": 0.95,
"temperature": 0.7,
}

DEFAULT_SERVER_PORT = 8000


class LLMPrompt(Tool):
"""
@@ -61,7 +73,9 @@ def run(
tokenizer: TokenizerAdapter = state.tokenizer

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
response = model.generate(input_ids, max_new_tokens=max_new_tokens)
response = model.generate(
input_ids, max_new_tokens=max_new_tokens, **DEFAULT_GENERATE_PARAMS
)
response_text = tokenizer.decode(response[0], skip_special_tokens=True).strip()

state.response = response_text
@@ -70,16 +84,32 @@ def run(
return state


# Custom huggingface-style stopping criteria to allow
# us to halt streaming in-progress generations
class StopOnEvent(StoppingCriteria):
def __init__(self, stop_event: Event):
super().__init__()
self.stop_event = stop_event

def __call__(self, input_ids, scores, **kwargs):
return self.stop_event.is_set()


class Serve(Tool):
"""
Open a web server that apps can use to communicate with the LLM.

There are two ways interact with the server:
There are two ways to perform generations with the server:
- Send an http request to "http://localhost:8000/generate" and
receive back a response with the complete prompt.
- Open a WebSocket with "ws://localhost:8000" and receive a
streaming response to the prompt.

The server also exposes these helpful endpoints:
- /health: check whether a model is loaded and ready to serve.
- /stats: performance statistics for the generation.
- /halt: stop an in-progress generation from making more tokens.

The WebSocket functionality is demonstrated by the webpage served at
http://localhost:8000, which you can visit with a web browser after
opening the server.
@@ -89,6 +119,7 @@ class Serve(Tool):
huggingface TextIteratorStreamer.
- state.tokenizer: tokenizer instance used to generate inputs for the
model. Must be compatible with the huggingface TextIteratorStreamer.
- state.checkpoint: name of the checkpoint used to load state.model.

Output state produced: None
"""
@@ -102,6 +133,17 @@ def __init__(self):
enable_logger=False,
)

# Performance stats that are set during /ws and can be
# fetched in /stats
self.time_to_first_token = None
self.tokens_per_second = None
self.input_tokens = None
self.output_tokens = None
self.decode_token_times = None

# Flag that tells the LLM to stop generating text and end the response
self.stop_event = Event()

@staticmethod
def parser(add_help: bool = True) -> argparse.ArgumentParser:
parser = __class__.helpful_parser(
@@ -151,10 +193,15 @@ class Message(BaseModel):
<input type="text" id="messageText" autocomplete="off"/>
<button type="submit">Send</button>
</form>
<button onclick="showStats()">Show Stats</button>
<button onclick="halt()">Halt</button>
<button onclick="health()">Health</button>
<p id="allMessages"></p> <!-- Use a <p> element to display all messages -->
<p id="statsMessage"></p> <!-- Use a <p> element to display stats message -->
<script>
const messageQueue = []; // Store incoming messages
const allMessagesContainer = document.getElementById('allMessages'); // Get the container element
const statsMessageContainer = document.getElementById('statsMessage'); // Get the stats message container
var ws = new WebSocket("ws://localhost:8000/ws");
ws.onmessage = function(event) {
const message = event.data;
@@ -173,6 +220,36 @@ class Message(BaseModel):
input.value = ''
event.preventDefault()
}
function showStats() {
fetch('/stats')
.then(response => response.json())
.then(data => {
statsMessageContainer.textContent = JSON.stringify(data); // Display the stats message
})
.catch(error => {
console.error('Error:', error);
});
}
function halt() {
fetch('/halt')
.then(response => response.json())
.then(data => {
statsMessageContainer.textContent = JSON.stringify(data); // Display the stats message
})
.catch(error => {
console.error('Error:', error);
});
}
function health() {
fetch('/health')
.then(response => response.json())
.then(data => {
statsMessageContainer.textContent = JSON.stringify(data); // Display the stats message
})
.catch(error => {
console.error('Error:', error);
});
}
</script>
</body>
</html>
@@ -188,11 +265,8 @@ async def generate_response(message: Message):
response = model.generate(
input_ids,
max_new_tokens=max_new_tokens,
do_sample=True,
top_k=50,
top_p=0.95,
temperature=0.7,
pad_token_id=tokenizer.eos_token_id,
**DEFAULT_GENERATE_PARAMS,
)
generated_text = tokenizer.decode(response[0], skip_special_tokens=True)

@@ -203,13 +277,23 @@ async def generate_response(message: Message):

@app.websocket("/ws")
async def stream_response(websocket: WebSocket):
"""
Receive a prompt string, and then stream the response back
over a websocket.
"""

await websocket.accept()
while True:

message = await websocket.receive_text()

if message == "done":
try:
message = await websocket.receive_text()
except WebSocketDisconnect:
print("Client closed connection")
break

# Reset the early-exit flag before we start each generation
self.stop_event.clear()

input_ids = tokenizer(message, return_tensors="pt").input_ids

# Set up the generation parameters
@@ -219,39 +303,109 @@ async def stream_response(websocket: WebSocket):

streamer = oga.OrtGenaiStreamer(tokenizer)

self.input_tokens = len(input_ids)

else:
# Huggingface-like models
streamer = TextIteratorStreamer(
tokenizer,
skip_prompt=True,
)

self.input_tokens = len(input_ids[0])

# Enable sending a signal into the generator thread to stop
# the generation early
stopping_criteria = StoppingCriteriaList([StopOnEvent(self.stop_event)])

generation_kwargs = {
"input_ids": input_ids,
"streamer": streamer,
"max_new_tokens": max_new_tokens,
"do_sample": True,
"top_k": 50,
"top_p": 0.95,
"temperature": 0.7,
"pad_token_id": tokenizer.eos_token_id,
"stopping_criteria": stopping_criteria,
**DEFAULT_GENERATE_PARAMS,
}

# Initialize performance variables
generation_start_time = time.perf_counter()
first_token = True
self.decode_token_times = []
self.output_tokens = 0

# Begin generation
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

# Generate the response using streaming
for new_text in streamer:

# Capture performance stats about this token
self.output_tokens = self.output_tokens + 1
if first_token:
self.time_to_first_token = (
time.perf_counter() - generation_start_time
)
first_token = False
else:
self.decode_token_times.append(
time.perf_counter() - next_token_start_time
)
next_token_start_time = time.perf_counter()

# Print the decoded value to the terminal for debugging purposes
print(new_text, end="", flush=True)

# Send the generated text to the client
await asyncio.sleep(0.1) # Add a small delay (adjust as needed)
await asyncio.sleep(0.001) # Add a small delay (adjust as needed)
await websocket.send_text(new_text)

# Allow the user to finish the response early
if self.stop_event.is_set():
print("Stopping generation early.")
break

self.tokens_per_second = 1 / statistics.mean(self.decode_token_times)
print("\n")
thread.join()

await websocket.close()

uvicorn.run(app, host="localhost", port=8000)
@app.get("/stats")
async def send_stats():
"""
Send performance statistics to the client.
"""
return {
"time_to_first_token": self.time_to_first_token,
"tokens_per_second": self.tokens_per_second,
"input_tokens": self.input_tokens,
"output_tokens": self.output_tokens,
"decode_token_times": self.decode_token_times,
}

@app.get("/halt")
async def halt_generation():
"""
Allow the client to halt an in-progress generation.
"""

self.stop_event.set()

return {
"terminated": True,
}

@app.get("/health")
async def health():
"""
Report server health information to the client.
"""

self.stop_event.set()

return {
"model_loaded": state.checkpoint,
}

uvicorn.run(app, host="localhost", port=DEFAULT_SERVER_PORT)

return state
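
For reference, here is a minimal client sketch for the endpoints added above. It assumes the third-party websockets and requests packages (not part of this PR) and the default port 8000; the prompt text is illustrative only, and the HTTP /generate endpoint is omitted because its request schema is not visible in this diff.

import asyncio

import requests
import websockets


async def stream_prompt(prompt: str):
    # Stream one generation over the /ws endpoint; the server closes
    # the socket when the response is finished, which ends the loop.
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        await ws.send(prompt)
        async for token in ws:
            print(token, end="", flush=True)
    print()


asyncio.run(stream_prompt("Hello, my name is"))

# Performance statistics collected during the generation above
print(requests.get("http://localhost:8000/stats").json())

# Check whether a model is loaded and ready to serve
print(requests.get("http://localhost:8000/health").json())

# Calling GET /halt from another process would stop an in-progress generation early.
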
13 changes: 10 additions & 3 deletions src/turnkeyml/llm/tools/huggingface_load.py
@@ -201,8 +201,15 @@ def __init__(self, model, dtype=torch.float32, device="cpu"):
self.dtype = dtype
self.device = device

def generate(self, input_ids, max_new_tokens=512, repetition_penalty=1.2,
do_sample=True, temperature=0.1, **kwargs):
def generate(
self,
input_ids,
max_new_tokens=512,
repetition_penalty=1.2,
do_sample=True,
temperature=0.1,
**kwargs,
):
amp_enabled = (
True
if (self.dtype == torch.float16 or self.dtype == torch.bfloat16)
@@ -221,7 +228,7 @@ def generate(self, input_ids, max_new_tokens=512, repetition_penalty=1.2,
repetition_penalty=repetition_penalty,
do_sample=do_sample,
temperature=temperature,
**kwargs
**kwargs,
)

