generic chat format
DenisDiachkov committed Sep 10, 2024
1 parent fa2a9d3 commit 90fd660
Showing 5 changed files with 55 additions and 39 deletions.
src/AGISwarm/llm_instruct_ms/app.py (6 changes: 4 additions & 2 deletions)
@@ -39,7 +39,7 @@ def __init__(self, config: LLMInstructConfig):
         self.sampling_settings_cls = ENGINE_SAMPLING_PARAMS_MAP[config.engine]
         self.queue_manager = AsyncIOQueueManager(
             max_concurrent_tasks=5,
-            sleep_time=0.0001,
+            sleep_time=0,
         )
         self.start_abort_lock = asyncio.Lock()
         self.setup_routes()
@@ -82,8 +82,8 @@ async def gui(self):
     async def generate(self, websocket: WebSocket):  # type: ignore
         """WebSocket endpoint"""
         await websocket.accept()
+        conversation_id = str(uuid.uuid4())
         try:
-            conversation_id = str(uuid.uuid4())
             while True:
                 data: Dict[str, Any] = await websocket.receive_json()
                 gen_config = SamplingConfig(data)
@@ -115,6 +115,7 @@ async def generate(self, websocket: WebSocket):  # type: ignore
                     gen_config.reply_prefix,
                     sampling_dict,
                 ):
+                    await asyncio.sleep(0)
                     if "status" not in step_info:  # Task's return value.
                         await websocket.send_json(
                             {
@@ -159,6 +160,7 @@ async def generate(self, websocket: WebSocket):  # type: ignore
         except WebSocketDisconnect:
             print("Client disconnected", flush=True)
         finally:
+            self.llm_pipeline.conversations.pop(conversation_id, None)
             await websocket.close()

class AbortRequest(BaseModel):
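Both scheduling tweaks above (sleep_time=0 for the queue manager and the added await asyncio.sleep(0) inside the streaming loop) serve the same purpose: handing control back to the event loop between chunks so that other coroutines, such as an abort request or a second client, are not starved. A minimal standalone sketch of that cooperative-yield pattern (illustrative only, not code from this repository):

    import asyncio

    async def stream_tokens(tokens):
        """Toy generator standing in for the real token stream."""
        for token in tokens:
            await asyncio.sleep(0)  # cooperative yield: lets abort handlers and other clients run
            yield token

    async def main():
        async for token in stream_tokens(["Hello", ",", " world"]):
            print(token, end="", flush=True)
        print()

    asyncio.run(main())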
src/AGISwarm/llm_instruct_ms/llm_engines/engine.py (62 changes: 37 additions & 25 deletions)
@@ -1,9 +1,11 @@
 """Utility functions for LLM engines"""

+import uuid
 from abc import abstractmethod
 from typing import Dict, Generic, List, TypeVar, cast

 from pydantic import BaseModel
+from transformers import PreTrainedTokenizerBase


 class SamplingParams(BaseModel):
@@ -20,7 +22,38 @@ class SamplingParams(BaseModel):


 # pylint: disable=too-few-public-methods
-class Engine(Generic[_SamplingParams_contra]):
+class PreparePromptMixin:
+    """Prepare prompt mixin"""
+
+    def prepare_prompt(
+        self,
+        processor: PreTrainedTokenizerBase,
+        messages: List[Dict[str, str]],
+        reply_prefix: str = "",
+    ):
+        """Prepare prompt for model"""
+        reply_prefix += " "
+        messages.append({"role": "assistant", "content": reply_prefix.strip()})
+        eot_uuid = "eot_" + str(uuid.uuid4())
+        prompt = (
+            cast(
+                str,
+                processor.apply_chat_template(
+                    messages,
+                    tokenize=False,
+                    # continue_final_message=True,
+                    add_generation_prompt=False,
+                ),
+            )
+            + eot_uuid
+        )
+        prompt = prompt.replace(processor.eos_token + eot_uuid, "")
+        prompt = prompt.replace(eot_uuid, "")
+        return prompt
+
+
+# pylint: disable=too-few-public-methods
+class Engine(Generic[_SamplingParams_contra], PreparePromptMixin):
     """Engine protocol"""

     conversations: Dict[str, List[Dict[str, str]]]
Expand All @@ -34,6 +67,8 @@ async def __call__(
reply_prefix: str,
sampling_params: _SamplingParams_contra,
):
if conversation_id not in self.conversations:
self.conversations[conversation_id] = []
if system_prompt != "":
self.conversations[conversation_id].append(
{
@@ -67,7 +102,7 @@ async def generate(


 # pylint: disable=too-few-public-methods
-class ConcurrentEngine(Generic[_SamplingParams_contra]):
+class ConcurrentEngine(Generic[_SamplingParams_contra], PreparePromptMixin):
     """Concurrent engine protocol"""

     conversations: Dict[str, List[Dict[str, str]]]
@@ -113,26 +148,3 @@ async def generate(
     ):
         """Generate text from prompt"""
         yield str()
-
-
-def prepare_prompt(
-    tokenizer: object,
-    messages: List[Dict[str, str]],
-    reply_prefix: str | None = None,
-    tokenize: bool = False,
-):
-    """Prepare prompt for model"""
-    if reply_prefix == "":
-        reply_prefix = None
-    prompt = cast(
-        str,
-        tokenizer.apply_chat_template(  # type: ignore
-            messages,
-            tokenize=tokenize,
-            add_generation_prompt=reply_prefix is None,
-        ),
-    ) + ("$" if reply_prefix else "")
-    if reply_prefix:
-        prompt = prompt.replace("<|eot_id|>$", "<|eot_id|>assistant\n\n" + reply_prefix)
-
-    return prompt
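The removed module-level prepare_prompt hard-coded Llama-3-style "<|eot_id|>" handling; the new PreparePromptMixin relies only on the tokenizer's own chat template and eos_token, which is presumably the "generic chat format" the commit message refers to. A hypothetical usage sketch, assuming the package is importable and a chat-template tokenizer is available (the model name is taken from SUPPORTED_MODELS in hf_engine.py and may require gated access):

    # Hypothetical usage sketch; not code from this commit.
    from transformers import AutoTokenizer
    from AGISwarm.llm_instruct_ms.llm_engines.engine import PreparePromptMixin

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a haiku about autumn."},
    ]
    prompt = PreparePromptMixin().prepare_prompt(tokenizer, messages, reply_prefix="Sure, here it is:")
    # The reply prefix is appended as a final assistant message, the templated prompt is
    # tagged with a unique "eot_<uuid>" sentinel, and the sentinel is stripped together with
    # the tokenizer's eos_token, so `prompt` should end mid-turn with "Sure, here it is:"
    # and the model continues that sentence instead of opening a new turn.
    print(prompt)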
src/AGISwarm/llm_instruct_ms/llm_engines/hf_engine.py (10 changes: 6 additions & 4 deletions)
@@ -6,8 +6,9 @@
 import torch
 import transformers  # type: ignore
 from pydantic import Field
+from transformers import AutoTokenizer  # type: ignore

-from .engine import Engine, SamplingParams, prepare_prompt
+from .engine import Engine, SamplingParams

 SUPPORTED_MODELS = [
     "meta-llama/Meta-Llama-3-8B-Instruct",
@@ -57,7 +58,9 @@ def __init__(
         tokenizer_name: str | None,
     ):

-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
+        self.conversations: Dict[str, List[Dict]] = {}
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or hf_model_name)
+
         self.pipeline = cast(
             transformers.TextGenerationPipeline,
             transformers.pipeline(
@@ -72,7 +75,6 @@ def __init__(
                 },
             ),
         )
-        self.conversations: Dict[str, List[Dict]] = {}

     async def generate(
         self,
@@ -84,7 +86,7 @@ async def generate(
         streamer = transformers.TextIteratorStreamer(
             self.tokenizer, skip_prompt=True, skip_special_tokens=True  # type: ignore
         )
-        prompt = prepare_prompt(self.tokenizer, messages, reply_prefix)
+        prompt = self.prepare_prompt(self.tokenizer, messages, reply_prefix)
         thread = Thread(
             target=self.pipeline,
             kwargs={
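For reference, the HF engine's generate() builds on the usual transformers streaming idiom: run generation in a worker thread and iterate a TextIteratorStreamer from the consumer side. A generic sketch of that idiom with a small stand-in model (not the project's exact code, which routes the call through a TextGenerationPipeline):

    from threading import Thread

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    # "gpt2" is only a small placeholder model for the sketch.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    thread = Thread(
        target=model.generate,
        kwargs={**inputs, "streamer": streamer, "max_new_tokens": 20},
    )
    thread.start()
    for new_text in streamer:  # blocks until the worker thread pushes the next chunk
        print(new_text, end="", flush=True)
    thread.join()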
src/AGISwarm/llm_instruct_ms/llm_engines/llama_cpp_engine.py (8 changes: 4 additions & 4 deletions)
@@ -4,9 +4,9 @@

 from llama_cpp import CreateCompletionStreamResponse, Llama
 from pydantic import Field
-from transformers import AutoTokenizer  # type: ignore
+from transformers import PreTrainedTokenizer

-from .engine import Engine, SamplingParams, prepare_prompt
+from .engine import Engine, SamplingParams


 class LlamaCppSamplingParams(SamplingParams):
@@ -31,7 +31,7 @@ def __init__(  # pylint: disable=too-many-arguments
         self.llama = Llama.from_pretrained(
             hf_model_name, filename=filename, n_gpu_layers=n_gpu_layers, n_ctx=n_ctx
         )
-        self.tokenizer: object = AutoTokenizer.from_pretrained(
+        self.tokenizer: PreTrainedTokenizer = PreTrainedTokenizer.from_pretrained(
             tokenizer_name or hf_model_name
         )
         self.conversations: Dict[str, List[Dict]] = {}
@@ -52,7 +52,7 @@ async def generate(
         sampling_params: LlamaCppSamplingParams = LlamaCppSamplingParams(),
     ):
         """Generate text from prompt"""
-        prompt = prepare_prompt(self.tokenizer, messages, reply_prefix)
+        prompt = self.prepare_prompt(self.tokenizer, messages, reply_prefix)
         sampling_params_dict = self.get_sampling_params(sampling_params)
         if reply_prefix:
             yield reply_prefix
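For context, the llama-cpp-python engine feeds the prepared prompt string to the completion API and streams CreateCompletionStreamResponse chunks (the class imported above). A rough sketch of that consumption pattern; repo_id and filename are placeholders, not values from this repository:

    from llama_cpp import Llama

    # Placeholder checkpoint: any GGUF model on the Hub works the same way.
    llama = Llama.from_pretrained(
        repo_id="<hf-repo-with-gguf>", filename="*q4_k_m.gguf", n_gpu_layers=-1, n_ctx=4096
    )
    prompt = "..."  # e.g. the string returned by PreparePromptMixin.prepare_prompt
    for chunk in llama.create_completion(prompt, max_tokens=128, stream=True):
        # Each streamed chunk carries the newly generated text fragment.
        print(chunk["choices"][0]["text"], end="", flush=True)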
src/AGISwarm/llm_instruct_ms/llm_engines/vllm_engine.py (8 changes: 4 additions & 4 deletions)
@@ -8,7 +8,7 @@
 from huggingface_hub import hf_hub_download
 from pydantic import Field

-from .engine import ConcurrentEngine, SamplingParams, prepare_prompt
+from .engine import ConcurrentEngine, SamplingParams


 class VLLMSamplingParams(SamplingParams):
@@ -32,6 +32,7 @@ def __init__(
             model = hf_hub_download(hf_model_name, filename)
         else:
             model = hf_model_name
+        self.conversations: Dict[str, List[Dict]] = {}
         self.model = vllm.AsyncLLMEngine.from_engine_args(
             vllm.AsyncEngineArgs(
                 model=model,
@@ -44,7 +45,6 @@ def __init__(
         )
         logging.info("Model loaded")
         self.tokenizer = asyncio.run(self.model.get_tokenizer())
-        self.conversations: Dict[str, List[Dict]] = {}

     def get_sampling_params(
         self, sampling_params: VLLMSamplingParams
@@ -65,12 +65,12 @@ def get_sampling_params(
     async def generate(
         self,
         messages: list[dict],
-        reply_prefix: str | None,
+        reply_prefix: str,
         sampling_params: VLLMSamplingParams,
         task_id: str,
     ):
         """Generate text from prompt"""
-        prompt = prepare_prompt(self.tokenizer, messages, reply_prefix)
+        prompt = self.prepare_prompt(self.tokenizer, messages, reply_prefix)
         vllm_sampling_params = self.get_sampling_params(sampling_params)
         current_len = 0
         if reply_prefix:
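The vLLM engine streams cumulative RequestOutput objects, which is why generate() tracks current_len and emits only the newly appended text. A rough sketch of that delta-slicing loop; the exact AsyncLLMEngine.generate signature varies between vLLM releases, so treat this as an assumption-laden outline rather than the project's exact call:

    import asyncio
    import uuid

    import vllm

    async def stream(prompt: str) -> None:
        # "facebook/opt-125m" is only a small placeholder model for the sketch.
        engine = vllm.AsyncLLMEngine.from_engine_args(vllm.AsyncEngineArgs(model="facebook/opt-125m"))
        sampling = vllm.SamplingParams(max_tokens=64)
        current_len = 0
        # request_id must be unique per request; argument order may differ by vLLM version.
        async for request_output in engine.generate(prompt, sampling, str(uuid.uuid4())):
            text = request_output.outputs[0].text  # cumulative text generated so far
            print(text[current_len:], end="", flush=True)
            current_len = len(text)

    asyncio.run(stream("The capital of France is"))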
