-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_pinecone.py
105 lines (85 loc) · 3.43 KB
/
main_pinecone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Main entrypoint for the app."""
import logging
import pathlib
from pathlib import Path
import sys
from typing import Optional
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.templating import Jinja2Templates
from langchain.vectorstores import VectorStore
from callback import QuestionGenCallbackHandler, StreamingLLMCallbackHandler
from query_data import get_chain
from schemas import ChatResponse
import pinecone
from langchain.vectorstores import Pinecone
from langchain.embeddings import OpenAIEmbeddings
from dotenv import load_dotenv
import os
# Directory used to locate the ``.env`` file: the folder of the running
# executable when ``sys.frozen`` is truthy, otherwise the folder of this file.
script_location = pathlib.Path(
    sys.executable if getattr(sys, 'frozen', False) else __file__
).parent.resolve()

# Load environment variables from the ``.env`` file in that directory.
load_dotenv(dotenv_path=script_location / '.env')

# Pinecone credentials; either value is None when the variable is not set.
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")  # find at app.pinecone.io
PINECONE_ENV = os.environ.get("PINECONE_ENV")  # next to api key in console

# Initialize the Pinecone client once, at import time.
pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_ENV)

# Web application and its template loader ("templates" directory).
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Shared vector store; stays None until the startup handler assigns it.
vectorstore: Optional[VectorStore] = None
@app.on_event("startup")
async def startup_event():
    """Bind the module-level ``vectorstore`` to the existing Pinecone index.

    Registered as a FastAPI ``startup`` handler. It builds an
    ``OpenAIEmbeddings`` client and attaches to the Pinecone index named
    ``data-1`` via ``Pinecone.from_existing_index``, storing the result in
    the global ``vectorstore``.

    The vector data is served by Pinecone, so no local ``vectorstore.pkl``
    is required or opened here.
    """
    global vectorstore

    logging.info("loading vectorstore")
    index_name = "data-1"
    # NOTE(review): "gpt-4" is a chat-model name, not an embedding-model
    # name. Confirm which embedding model the index was ingested with
    # before changing it; the query embeddings must match the index.
    embeddings = OpenAIEmbeddings(model="gpt-4")
    vectorstore = Pinecone.from_existing_index(index_name, embeddings)
@app.get("/")
async def get(request: Request):
    """Render ``index.html`` for the root URL, passing the request to the template."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.websocket("/chat")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one chat session over the ``/chat`` websocket.

    After accepting the connection, the handler loops forever:

    1. receive a question as text and echo it back (sender ``"you"``),
    2. send a ``"start"`` marker, run the QA chain with the question and
       the accumulated chat history, then send an ``"end"`` marker,
    3. append ``(question, answer)`` to the per-connection history.

    The loop ends when the client disconnects. Any other failure is logged
    with its traceback and reported to the client as an ``"error"`` message;
    if that report cannot be delivered, the session is closed instead of
    letting the exception escape the endpoint.

    Args:
        websocket: The client connection supplied by FastAPI.
    """
    await websocket.accept()
    question_handler = QuestionGenCallbackHandler(websocket)
    stream_handler = StreamingLLMCallbackHandler(websocket)
    chat_history = []
    # NOTE(review): the earlier comment offered
    # `get_chain(vectorstore, question_handler, stream_handler, tracing=True)`
    # as the tracing-enabled variant (which needs `langchain-server` running).
    # This call already passes True as the fourth positional argument;
    # confirm whether that argument is the tracing flag.
    qa_chain = get_chain(vectorstore, question_handler, stream_handler, True)

    while True:
        try:
            # Receive and send back the client message
            question = await websocket.receive_text()
            resp = ChatResponse(sender="you", message=question, type="stream")
            await websocket.send_json(resp.dict())

            # Construct a response
            start_resp = ChatResponse(sender="bot", message="", type="start")
            await websocket.send_json(start_resp.dict())

            result = await qa_chain.acall(
                {"question": question, "chat_history": chat_history}
            )
            chat_history.append((question, result["answer"]))

            end_resp = ChatResponse(sender="bot", message="", type="end")
            await websocket.send_json(end_resp.dict())
        except WebSocketDisconnect:
            logging.info("websocket disconnect")
            break
        except Exception:
            # logging.exception keeps the traceback, unlike logging.error(e).
            logging.exception("error while handling chat message")
            resp = ChatResponse(
                sender="bot",
                message="Sorry, something went wrong. Try again.",
                type="error",
            )
            try:
                await websocket.send_json(resp.dict())
            except Exception:
                # The socket is unusable (e.g. already closed), so stop
                # rather than raise out of the handler or spin on a dead
                # connection.
                logging.info("could not deliver error message; closing chat session")
                break
if __name__ == "__main__":
    # Run directly as a script: serve the FastAPI app with uvicorn on port 9000.
    import uvicorn

    uvicorn.run(app=app, port=9000)