Commit
Merge pull request #92 from AIDotNet/feature_llamafactory
update llamafactory 0.8.0
Showing 146 changed files with 9,184 additions and 4,358 deletions.
@@ -0,0 +1,19 @@
import os

import uvicorn

from llamafactory.api.app import create_app
from llamafactory.chat import ChatModel


def main():
    chat_model = ChatModel()
    app = create_app(chat_model)
    api_host = os.environ.get("API_HOST", "0.0.0.0")
    api_port = int(os.environ.get("API_PORT", "8000"))
    print("Visit http://localhost:{}/docs for API document.".format(api_port))
    uvicorn.run(app, host=api_host, port=api_port)


if __name__ == "__main__":
    main()
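This entry script is configured purely through environment variables. As a quick illustration (not part of the diff), the host and port could be overridden before the server starts; the import path is an assumption, since the diff does not show this file's name:

import os

# Hypothetical wrapper: set the variables that main() above reads before
# the server binds. Any other variable names would simply be ignored.
os.environ["API_HOST"] = "127.0.0.1"
os.environ["API_PORT"] = "8123"

# from api import main  # assumed module name, not shown in the diff
# main()                # builds the FastAPI app around a ChatModel and serves it with uvicorn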
This file was deleted.
@@ -0,0 +1,6 @@
# Level: api, webui > chat, eval, train > data, model > hparams > extras

from .cli import VERSION


__version__ = VERSION
File renamed without changes.
src/AntSK.LLamaFactory/llamafactory/llamafactory/api/app.py (108 additions, 0 deletions)
@@ -0,0 +1,108 @@
import os
from contextlib import asynccontextmanager
from typing import Optional

from typing_extensions import Annotated

from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
from .chat import (
    create_chat_completion_response,
    create_score_evaluation_response,
    create_stream_chat_completion_response,
)
from .protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ModelCard,
    ModelList,
    ScoreEvaluationRequest,
    ScoreEvaluationResponse,
)


if is_fastapi_available():
    from fastapi import Depends, FastAPI, HTTPException, status
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer


if is_starlette_available():
    from sse_starlette import EventSourceResponse


if is_uvicorn_available():
    import uvicorn


@asynccontextmanager
async def lifespan(app: "FastAPI"):  # collects GPU memory
    yield
    torch_gc()


def create_app(chat_model: "ChatModel") -> "FastAPI":
    app = FastAPI(lifespan=lifespan)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    api_key = os.environ.get("API_KEY")
    security = HTTPBearer(auto_error=False)

    async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
        if api_key and (auth is None or auth.credentials != api_key):
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")

    @app.get(
        "/v1/models",
        response_model=ModelList,
        status_code=status.HTTP_200_OK,
        dependencies=[Depends(verify_api_key)],
    )
    async def list_models():
        model_card = ModelCard(id="gpt-3.5-turbo")
        return ModelList(data=[model_card])

    @app.post(
        "/v1/chat/completions",
        response_model=ChatCompletionResponse,
        status_code=status.HTTP_200_OK,
        dependencies=[Depends(verify_api_key)],
    )
    async def create_chat_completion(request: ChatCompletionRequest):
        if not chat_model.engine.can_generate:
            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

        if request.stream:
            generate = create_stream_chat_completion_response(request, chat_model)
            return EventSourceResponse(generate, media_type="text/event-stream")
        else:
            return await create_chat_completion_response(request, chat_model)

    @app.post(
        "/v1/score/evaluation",
        response_model=ScoreEvaluationResponse,
        status_code=status.HTTP_200_OK,
        dependencies=[Depends(verify_api_key)],
    )
    async def create_score_evaluation(request: ScoreEvaluationRequest):
        if chat_model.engine.can_generate:
            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

        return await create_score_evaluation_response(request, chat_model)

    return app


def run_api() -> None:
    chat_model = ChatModel()
    app = create_app(chat_model)
    api_host = os.environ.get("API_HOST", "0.0.0.0")
    api_port = int(os.environ.get("API_PORT", "8000"))
    print("Visit http://localhost:{}/docs for API document.".format(api_port))
    uvicorn.run(app, host=api_host, port=api_port)
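create_app wires up an OpenAI-style HTTP surface: /v1/models, /v1/chat/completions, and /v1/score/evaluation, all optionally protected by a Bearer token taken from the API_KEY environment variable. A minimal client sketch (not part of the diff) assuming the server was started via run_api() on the default host and port, with API_KEY set to "my-key":

import requests

BASE_URL = "http://localhost:8000"
HEADERS = {"Authorization": "Bearer my-key"}  # omit if API_KEY is not set on the server

# /v1/models returns a single placeholder card with id "gpt-3.5-turbo".
models = requests.get(f"{BASE_URL}/v1/models", headers=HEADERS)
print(models.json())

# /v1/chat/completions follows the OpenAI chat schema; the "model" field is
# echoed back in the response rather than selecting a different checkpoint.
payload = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "Hello!"}],
    "temperature": 0.7,
}
completion = requests.post(f"{BASE_URL}/v1/chat/completions", headers=HEADERS, json=payload)
print(completion.json()["choices"][0]["message"]["content"])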
src/AntSK.LLamaFactory/llamafactory/llamafactory/api/chat.py (219 additions, 0 deletions)
@@ -0,0 +1,219 @@
import base64
import io
import json
import os
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple

from ..data import Role as DataRole
from ..extras.logging import get_logger
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify
from .protocol import (
    ChatCompletionMessage,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseUsage,
    ChatCompletionStreamResponse,
    ChatCompletionStreamResponseChoice,
    Finish,
    Function,
    FunctionCall,
    Role,
    ScoreEvaluationResponse,
)


if is_fastapi_available():
    from fastapi import HTTPException, status


if is_pillow_available():
    from PIL import Image


if is_requests_available():
    import requests


if TYPE_CHECKING:
    from numpy.typing import NDArray

    from ..chat import ChatModel
    from .protocol import ChatCompletionRequest, ScoreEvaluationRequest


logger = get_logger(__name__)
ROLE_MAPPING = {
    Role.USER: DataRole.USER.value,
    Role.ASSISTANT: DataRole.ASSISTANT.value,
    Role.SYSTEM: DataRole.SYSTEM.value,
    Role.FUNCTION: DataRole.FUNCTION.value,
    Role.TOOL: DataRole.OBSERVATION.value,
}


def _process_request(
    request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
    logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))

    if len(request.messages) == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")

    if request.messages[0].role == Role.SYSTEM:
        system = request.messages.pop(0).content
    else:
        system = None

    if len(request.messages) % 2 == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

    input_messages = []
    image = None
    for i, message in enumerate(request.messages):
        if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
        elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

        if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
            name = message.tool_calls[0].function.name
            arguments = message.tool_calls[0].function.arguments
            content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
            input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
        elif isinstance(message.content, list):
            for input_item in message.content:
                if input_item.type == "text":
                    input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
                else:
                    image_url = input_item.image_url.url
                    if image_url.startswith("data:image"):  # base64 image
                        image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
                        image_path = io.BytesIO(image_data)
                    elif os.path.isfile(image_url):  # local file
                        image_path = open(image_url, "rb")
                    else:  # web uri
                        image_path = requests.get(image_url, stream=True).raw

                    image = Image.open(image_path).convert("RGB")
        else:
            input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

    tool_list = request.tools
    if isinstance(tool_list, list) and len(tool_list):
        try:
            tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
        except Exception:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
    else:
        tools = None

    return input_messages, system, tools, image


def _create_stream_chat_completion_chunk(
    completion_id: str,
    model: str,
    delta: "ChatCompletionMessage",
    index: Optional[int] = 0,
    finish_reason: Optional["Finish"] = None,
) -> str:
    choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
    chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
    return jsonify(chunk)


async def create_chat_completion_response(
    request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
    input_messages, system, tools, image = _process_request(request)
    responses = await chat_model.achat(
        input_messages,
        system,
        tools,
        image,
        do_sample=request.do_sample,
        temperature=request.temperature,
        top_p=request.top_p,
        max_new_tokens=request.max_tokens,
        num_return_sequences=request.n,
        stop=request.stop,
    )

    prompt_length, response_length = 0, 0
    choices = []
    for i, response in enumerate(responses):
        if tools:
            result = chat_model.engine.template.format_tools.extract(response.response_text)
        else:
            result = response.response_text

        if isinstance(result, tuple):
            name, arguments = result
            function = Function(name=name, arguments=arguments)
            tool_call = FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function)
            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=[tool_call])
            finish_reason = Finish.TOOL
        else:
            response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
            finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH

        choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
        prompt_length = response.prompt_length
        response_length += response.response_length

    usage = ChatCompletionResponseUsage(
        prompt_tokens=prompt_length,
        completion_tokens=response_length,
        total_tokens=prompt_length + response_length,
    )

    return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)


async def create_stream_chat_completion_response(
    request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
    input_messages, system, tools, image = _process_request(request)
    if tools:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")

    if request.n > 1:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")

    yield _create_stream_chat_completion_chunk(
        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
    )
    async for new_token in chat_model.astream_chat(
        input_messages,
        system,
        tools,
        image,
        do_sample=request.do_sample,
        temperature=request.temperature,
        top_p=request.top_p,
        max_new_tokens=request.max_tokens,
        stop=request.stop,
    ):
        if len(new_token) != 0:
            yield _create_stream_chat_completion_chunk(
                completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
            )

    yield _create_stream_chat_completion_chunk(
        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
    )
    yield "[DONE]"


async def create_score_evaluation_response(
    request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse":
    if len(request.messages) == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

    scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
    return ScoreEvaluationResponse(model=request.model, scores=scores)
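_process_request enforces a strict shape on incoming requests: an optional leading system message, then strictly alternating user/assistant turns that must end on a user turn, with tool results mapped onto the internal observation role and at most one image extracted per request. A hedged illustration of payloads that would pass or fail this validation (field names follow the OpenAI-style protocol used above; the exact Pydantic models live in the protocol module added elsewhere in this PR):

# Passes validation: optional system message first, then an odd number of
# alternating turns ending with the user (u/a/u...).
valid_payload = {
    "model": "gpt-3.5-turbo",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is in this picture?"},
    ],
}

# Multimodal variant: a content list mixing text and an image item. Base64
# data URIs, local file paths, and web URLs are all accepted for the image.
multimodal_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this image."},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ],
}

# Rejected with HTTP 400: an even number of turns after the system message
# violates the "Only supports u/a/u/a/u..." rule. Streaming requests that also
# set "tools" or n > 1 are likewise rejected before generation starts.
invalid_payload = {
    "model": "gpt-3.5-turbo",
    "messages": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ],
}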