-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathgpt_api.py
51 lines (43 loc) · 1.67 KB
/
gpt_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from openai.embeddings_utils import get_embedding, cosine_similarity
import pandas as pd
from utils import api_error_warning
import openai
import streamlit as st
def find_top_similar_results(df: pd.DataFrame, query: str, n: int):
    """Return the rows of *df* whose ``ada_search`` values are most similar to *query*.

    The query is embedded with ``create_embedding`` and compared against every
    value in ``df["ada_search"]`` using ``cosine_similarity``.
    - Rows are ranked from most to least similar.
    - A row whose ``text`` already appeared earlier in the ranking is dropped.
    - The first *n* remaining rows are returned.

    Args:
        df: Table with an ``ada_search`` column and a ``text`` column.
            ``ada_search`` presumably holds precomputed embedding vectors for
            each row; confirm against the caller.
        query: Free-text search query.
        n: Maximum number of rows to return.

    Returns:
        A new DataFrame with up to *n* rows with distinct ``text`` values,
        ordered by decreasing similarity. The ``ada_search`` column and the
        temporary ``similarities`` column are removed. *df* is not modified.
    """
    embedding = create_embedding(query)
    # ``assign`` returns a new frame, so the caller's DataFrame is untouched.
    scored = df.assign(
        similarities=df["ada_search"].apply(lambda x: cosine_similarity(x, embedding))
    )
    # Deduplicate BEFORE truncating. Otherwise repeated texts consume top-n
    # slots and the result can shrink below n. ``head`` already copes with
    # n > len(df), so no manual clamp is needed.
    best_results = (
        scored.sort_values("similarities", ascending=False)
        .drop_duplicates(subset=["text"])
        .head(n)
    )
    return best_results.drop(["similarities", "ada_search"], axis=1)
def create_embedding(query):
    """Return the ``text-embedding-ada-002`` embedding of *query*.

    Non-ASCII characters are removed from *query* before the request is made.

    On any error from ``get_embedding``, this function:
    - shows a warning via ``api_error_warning``;
    - halts the current Streamlit script run with ``st.stop()``.
    It therefore never returns normally on failure.

    Args:
        query: Text to embed.

    Returns:
        Whatever ``get_embedding`` returns for the cleaned query.
    """
    # errors='ignore' silently drops any character that is not ASCII.
    query = query.encode(encoding='ASCII', errors='ignore').decode()
    # The call must live inside the try. A stray unconditional return placed
    # before it previously made this error handling unreachable.
    try:
        return get_embedding(query, engine="text-embedding-ada-002")
    except Exception:
        api_error_warning()
        st.stop()
def test_api_key(api_key):
    """Install *api_key* as the OpenAI key and verify it with a tiny embedding call.

    Side effect: assigns ``openai.api_key`` globally for the whole process,
    whether or not validation succeeds.

    If the probe request fails, this function:
    - shows a warning via ``api_error_warning``;
    - removes ``'api_key'`` from ``st.session_state`` if present;
    - halts the Streamlit script run with ``st.stop()``.

    Args:
        api_key: The OpenAI API key to validate.
    """
    openai.api_key = api_key
    with st.spinner("Validating API key..."):
        try:
            # A one-character embedding is a cheap request that exercises the key.
            get_embedding('a', engine="text-embedding-ada-002")
        except Exception:
            api_error_warning()
            # Forget the rejected key so the app does not keep reusing it.
            if 'api_key' in st.session_state:
                st.session_state.pop('api_key')
            st.stop()
def gpt3_call(prompt, tokens: int, temperature: float = 1, stop=None):
    """Send a chat request to ``gpt-3.5-turbo`` and return the first reply's text.

    Args:
        prompt: Passed straight through as ``messages`` to
            ``openai.ChatCompletion.create``. Presumably a list of
            ``{"role": ..., "content": ...}`` dicts; confirm with callers.
        tokens: Value for ``max_tokens``.
        temperature: Sampling temperature. Fractional values are accepted,
            hence ``float`` rather than ``int``.
        stop: Optional stop sequence(s), forwarded as ``stop``.

    Returns:
        The content of ``choices[0].message`` with every ``'\\n'`` replaced by
        ``'  \\n'``. On any exception, the error is printed,
        ``api_error_warning`` is called, and ``None`` is returned.
    """
    try:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=prompt,
            max_tokens=tokens,
            stop=stop,
            temperature=temperature,
        )
        # Two spaces before a newline presumably force a hard line break when
        # the text is rendered as Markdown by Streamlit.
        return response["choices"][0]['message']["content"].replace('\n', '  \n')
    except Exception as e:
        print(e)
        api_error_warning()
        return None