Skip to content

Commit

Permalink
🚧 Updated response model
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilippGawlik committed Jul 18, 2024
1 parent a77b009 commit 2a5993f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
20 changes: 16 additions & 4 deletions api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Optional, Generator
from fastapi import FastAPI, Depends, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
import uvicorn

from langchain_core.documents.base import Document

from interface.response_models import ResponseModel
from interface.response_models import ResponseModel, CTA, CTAType
from interface.request_models import RequestModel
from src.context import get_context
#from src.generate_with_azure import generate_answer
Expand Down Expand Up @@ -51,14 +53,24 @@ def answer_a_question(query: RequestModel) -> ResponseModel:
context = get_context(query.question)
prompt = assemble_prompt(query.question, context)
answer = generate_answer(prompt)
refs = [c.metadata["title"] for c in context]
return ResponseModel(
status="ok",
msg="Successfully generated answer",
answer=answer,
cta=[c.metadata["metadata_storage_path"] for c in context],
refs=[c.metadata["title"] for c in context]
cta=list(generate_cta(context)),
refs=refs
)


def generate_cta(context: list[Optional[Document]]) -> Generator[CTA, None, None]:
for c in context:
yield CTA(
type=CTAType.LINK,
text=c.metadata["title"],
payload=f"todo/{c.metadata['title']}"
)


if __name__ == "__main__":
uvicorn.run(APP, host="0.0.0.0", port=3000, timeout_keep_alive=20)
uvicorn.run(APP, host="0.0.0.0", port=3000)
19 changes: 18 additions & 1 deletion interface/response_models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,27 @@
from pydantic import BaseModel
from typing import Mapping, Any
from enum import Enum


class CTAType(Enum):
MAIL = "MAIL"
CALL = "CALL"
CITATION = "CITATION"
LINK = "LINK"
MEDIA = "MEDIA"
TEAMS_CHAT = "TEAMS_CHAT"
CALENDAR = "CALENDAR"


class CTA(BaseModel):
type: CTAType
text: str
payload: Any


class ResponseModel(BaseModel):
answer: str
refs: list[str]
status: str
msg: str
cta: list[Any]
cta: list[CTA]

0 comments on commit 2a5993f

Please sign in to comment.