Skip to content

Commit

Permalink
update search to use new models
Browse files Browse the repository at this point in the history
  • Loading branch information
Kav-K committed Nov 22, 2023
1 parent 19c1d27 commit b3024ff
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 25 deletions.
18 changes: 10 additions & 8 deletions cogs/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,7 @@ async def chat(
input_type=discord.SlashCommandOptionType.integer,
max_value=16,
min_value=1,
default=3,
)
@discord.option(
name="nodes",
Expand All @@ -1331,6 +1332,7 @@ async def chat(
input_type=discord.SlashCommandOptionType.integer,
max_value=8,
min_value=1,
default=4,
)
@discord.option(
name="deep",
Expand All @@ -1343,14 +1345,14 @@ async def chat(
description="Response mode, doesn't work on deep searches",
guild_ids=ALLOWED_GUILDS,
required=False,
default="refine",
default="compact",
choices=["refine", "compact", "tree_summarize"],
)
@discord.option(
name="model",
description="The model to use for the request (querying, not composition)",
required=False,
default="gpt-4-32k",
default="gpt-4-1106-preview",
autocomplete=Settings_autocompleter.get_index_and_search_models,
)
@discord.option(
Expand All @@ -1365,12 +1367,12 @@ async def search(
self,
ctx: discord.ApplicationContext,
query: str,
scope: int,
nodes: int,
deep: bool,
response_mode: str,
model: str,
multistep: bool,
scope: int = 3,
nodes: int = 4,
deep: bool = False,
response_mode: str = "refine",
model: str ="gpt-4-1106-preview",
multistep: bool = False,
):
await self.search_cog.search_command(
ctx,
Expand Down
31 changes: 14 additions & 17 deletions models/search_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,14 @@ async def search(
multistep=False,
redo=None,
):
DEFAULT_SEARCH_NODES = 1
DEFAULT_SEARCH_NODES = 4
if not user_api_key:
os.environ["OPENAI_API_KEY"] = self.openai_key
else:
os.environ["OPENAI_API_KEY"] = user_api_key
openai.api_key = os.environ["OPENAI_API_KEY"]


# Initialize the search cost
price = 0

Expand All @@ -239,28 +240,21 @@ async def search(
)

try:
llm_predictor_presearch = OpenAI(
max_tokens=50,
temperature=0.4,
llm_predictor_presearch = ChatOpenAI(
max_tokens=100,
temperature=0,
presence_penalty=0.65,
model_name="text-davinci-003",
model_name=model,
)

# Refine a query to send to google custom search API
prompt = f"You are to be given a search query for google. Change the query such that putting it into the Google Custom Search API will return the most relevant websites to assist in answering the original query. If the original query is inferring knowledge about the current day, insert the current day into the refined prompt. If the original query is inferring knowledge about the current month, insert the current month and year into the refined prompt. If the original query is inferring knowledge about the current year, insert the current year into the refined prompt. Generally, if the original query is inferring knowledge about something that happened recently, insert the current month into the refined query. Avoid inserting a day, month, or year for queries that purely ask about facts and about things that don't have much time-relevance. The current date is {str(datetime.now().date())}. Do not insert the current date if not neccessary. Respond with only the refined query for the original query. Don’t use punctuation or quotation marks.\n\nExamples:\n---\nOriginal Query: ‘Who is Harald Baldr?’\nRefined Query: ‘Harald Baldr biography’\n---\nOriginal Query: ‘What happened today with the Ohio train derailment?’\nRefined Query: ‘Ohio train derailment details {str(datetime.now().date())}\n---\nOriginal Query: ‘Is copper in drinking water bad for you?’\nRefined Query: ‘copper in drinking water adverse effects’\n---\nOriginal Query: What's the current time in Mississauga?\nRefined Query: current time Mississauga\nNow, refine the user input query.\nOriginal Query: {query}\nRefined Query:"
query_refined = await llm_predictor_presearch.agenerate(
prompts=[prompt],
query_refined = await llm_predictor_presearch.apredict(
text=prompt,
)
query_refined_text = query_refined.generations[0][0].text
query_refined_text = query_refined

await self.usage_service.update_usage(
query_refined.llm_output.get("token_usage").get("total_tokens"),
"davinci",
)
price += await self.usage_service.get_price(
query_refined.llm_output.get("token_usage").get("total_tokens"),
"davinci",
)
print("The query refined text is: " + query_refined_text)

except Exception as e:
traceback.print_exc()
Expand Down Expand Up @@ -345,7 +339,10 @@ async def search(

embedding_model = OpenAIEmbedding()

llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0, model_name=model))
if "vision" in model:
llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0, model=model, max_tokens=4096))
else:
llm_predictor = LLMPredictor(llm=ChatOpenAI(temperature=0, model=model))

token_counter = TokenCountingHandler(
tokenizer=tiktoken.encoding_for_model(model).encode, verbose=False
Expand Down

0 comments on commit b3024ff

Please sign in to comment.