Skip to content

Commit

Permalink
Merge pull request #17 from Predacons/feature/upgrade
Browse files: browse the repository at this point in the history
minor fixes
  • Loading branch information
shouryashashank authored Oct 27, 2024
2 parents 7648ff5 + 808aca3 commit 408b39f
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 27 deletions.
14 changes: 8 additions & 6 deletions chat_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@ class Message:
@dataclass
class Conversation:
messages: List[Message]
max_tokens: int
temperature: float
frequency_penalty: float
presence_penalty: float
top_p: float
stop: Optional[str]
max_tokens: int = 500
temperature: float = 0.7
frequency_penalty: float = None
presence_penalty: float = None
top_p: float = None
stop: Optional[str] = None
model: Optional[str] = None
encoding_format: str = None

@dataclass
class FilterCategory:
Expand Down
6 changes: 4 additions & 2 deletions embedding_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,7 @@ class EmbeddingInput:
@dataclass
class EmbeddingInput:
input: str
user: str
input_type: str
user: str = None
input_type: str = None
model: str = None
encoding_format: str = None
11 changes: 3 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ async def startup_event():
# Load models
global predacons_models
predacons_models = await PredaconsRepo.load_models()
print(predacons_models)
# Load tokenizers


Expand Down Expand Up @@ -56,27 +55,24 @@ async def get_api_key(
)


@app.post("/deployments/{model}/chat/completions", dependencies=[Depends(get_api_key)])
@app.post("/openai/deployments/{model}/chat/completions", dependencies=[Depends(get_api_key)])
async def chat_completions(request: Request,model:str ,api_version:str = Query(default=None, alias="api-version")):
body = await request.json()
print("Entry Chat Completions")
print(model)
if model not in predacons_models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Model {model} not found",
)

print(body)
print(model)
print(api_version)
return await ChatService.completions(body, predacons_models[model], api_version)

@app.post("/deployments/{model}/completions", dependencies=[Depends(get_api_key)])
@app.post("/openai/deployments/{model}/completions", dependencies=[Depends(get_api_key)])
async def nocontext_completions_endpoint(request: Request, model:str, api_version:str = Query(default=None, alias="api-version")):
body = await request.json()
print("Entry NoContext Completions Endpoint")
print(model)
if model not in predacons_models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand All @@ -88,11 +84,10 @@ async def nocontext_completions_endpoint(request: Request, model:str, api_versio
print(api_version)
return await ChatService.nocontext_completions(body, predacons_models[model], api_version)

@app.post("/deployments/{model}/embeddings", dependencies=[Depends(get_api_key)])
@app.post("/openai/deployments/{model}/embeddings", dependencies=[Depends(get_api_key)])
async def embeddings_endpoint(request: Request,model:str, api_version:str = Query(default=None, alias="api-version")):
body = await request.json()
print("Entry Embeddings Endpoint")
print(model)
if model not in predacons_models:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand Down
2 changes: 0 additions & 2 deletions repo/predacons.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ async def load_model(model_name:str):
# model = model_name + path
# tokenizers = model_name + "tokenizer"
predacons_model = PredaconsModel(model_name, path, trust_remote_code, use_fast_generation, draft_model_name, model, tokenizers)
print(f"Model {model_name} loaded: {model}")
print(f"Tokenizer {model_name} loaded: {tokenizers}")
end_time = time.time()
print(f"Model {model_name} loaded successfully in {end_time - start_time} seconds.")
return predacons_model
Expand Down
11 changes: 2 additions & 9 deletions service/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import dotenv
from chat_class import Message, Conversation, ContentFilterResults, Choice, PromptFilterResults, Usage, ChatResponse,FilterCategory
from embedding_class import Embedding, Usage, EmbeddingResponse,EmbeddingInput
from embedding_class import Embedding, Usage as embedding_usage, EmbeddingResponse,EmbeddingInput

dotenv.load_dotenv()

Expand All @@ -22,7 +22,6 @@ async def generate_cmpl_id():

async def completions(conversation_body:str, model_dict, api_version:str = None):
print("Entry Chat Completions")
print(model_dict)
print(api_version)
system_fingerprint = os.getenv('system_fingerprint')
# print(conversation_body)
Expand All @@ -37,8 +36,6 @@ async def completions(conversation_body:str, model_dict, api_version:str = None)
trust_remote_code = model_dict.trust_remote_code
fast_gen = model_dict.use_fast_generation
draft_model = model_dict.draft_model_name
print(model)
print(tokenizer)

response = predacons.chat_generate(model = model,
sequence = conversation.messages,
Expand Down Expand Up @@ -84,7 +81,6 @@ async def completions(conversation_body:str, model_dict, api_version:str = None)

async def nocontext_completions(conversation_body:str, model_dict, api_version:str = None):
print("Entry NoContext Completions")
print(model_dict)
print(api_version)
system_fingerprint = os.getenv('system_fingerprint')

Expand All @@ -110,8 +106,6 @@ async def nocontext_completions(conversation_body:str, model_dict, api_version:s
trust_remote_code = model_dict.trust_remote_code
fast_gen = model_dict.use_fast_generation
draft_model = model_dict.draft_model_name
print(model)
print(tokenizer)

output,tokenizer = predacons.generate(model = model,
sequence = message_str,
Expand Down Expand Up @@ -154,7 +148,6 @@ async def nocontext_completions(conversation_body:str, model_dict, api_version:s

async def embeddings(body, model_dict,model, api_version:str = None):
print("Entry Embeddings")
print(model_dict)
print(api_version)

embedding_input = EmbeddingInput(**body)
Expand All @@ -168,7 +161,7 @@ async def embeddings(body, model_dict,model, api_version:str = None):
for i, embedding in enumerate(embeddings):
embeddings_list.append(Embedding(object = "embedding", index = i, embedding = embedding))

usage = Usage(prompt_tokens = len(embeddings_list), total_tokens = len(embeddings_list))
usage = embedding_usage(prompt_tokens = len(embeddings_list), total_tokens = len(embeddings_list))
embedding_response = EmbeddingResponse(object = "list", data = embeddings_list, model = model, usage = usage)
return embedding_response

0 comments on commit 408b39f

Please sign in to comment.