Skip to content

Commit

Permalink
Fix indexing and search (#327)
Browse files Browse the repository at this point in the history
* Update dependencies

* Round usage amounts
Remove traceback and lighten up the print statements about not having pickles

* Implement new saving and loading from llamaindex
Change querying in search and index to use the new query_engine
New indexing saves in a folder instead of a single .json so all older indexes won't work
Improved internet search price reporting accuracy
Increase the usage rounding by 2 digits
Deep search does not use the query_config; implementations have changed

* Format Python code with psf/black push

* bump version

---------

Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com>
  • Loading branch information
cherryroots and github-actions authored Jun 14, 2023
1 parent 2b291ad commit 29cdd38
Show file tree
Hide file tree
Showing 9 changed files with 160 additions and 169 deletions.
8 changes: 4 additions & 4 deletions cogs/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,8 +898,8 @@ async def discord_backup(self, ctx: discord.ApplicationContext, message_limit: i
description="Response mode, doesn't work on deep composed indexes",
guild_ids=ALLOWED_GUILDS,
required=False,
default="default",
choices=["default", "compact", "tree_summarize"],
default="refine",
choices=["refine", "compact", "tree_summarize"],
)
@discord.option(
name="child_branch_factor",
Expand Down Expand Up @@ -1182,8 +1182,8 @@ async def chat(
description="Response mode, doesn't work on deep searches",
guild_ids=ALLOWED_GUILDS,
required=False,
default="default",
choices=["default", "compact", "tree_summarize"],
default="refine",
choices=["refine", "compact", "tree_summarize"],
)
@discord.option(
name="model",
Expand Down
4 changes: 2 additions & 2 deletions cogs/text_service_cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,11 @@ async def on_ready(self):
assert self.conversation_thread_owners is not defaultdict(list)

except Exception:
print("Failed to load from pickles")
print("Failed to load existing pickles")
self.full_conversation_history = defaultdict(list)
self.conversation_threads = {}
self.conversation_thread_owners = defaultdict(list)
traceback.print_exc()
print("Set empty dictionaries, pickles will be saved in the future")

print("Syncing commands...")

Expand Down
2 changes: 1 addition & 1 deletion gpt3discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from models.openai_model import Model


__version__ = "11.6.2"
__version__ = "11.7.0"


PID_FILE = Path("bot.pid")
Expand Down
110 changes: 62 additions & 48 deletions models/index_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import tempfile
import traceback
import asyncio
import json
from collections import defaultdict

import aiohttp
Expand Down Expand Up @@ -36,6 +35,9 @@
from llama_index.readers.schema.base import Document
from llama_index.langchain_helpers.text_splitter import TokenTextSplitter

from llama_index.retrievers import VectorIndexRetriever, TreeSelectLeafRetriever
from llama_index.query_engine import RetrieverQueryEngine, MultiStepQueryEngine

from llama_index import (
GPTVectorStoreIndex,
SimpleDirectoryReader,
Expand All @@ -50,6 +52,9 @@
download_loader,
LLMPredictor,
ServiceContext,
StorageContext,
ResponseSynthesizer,
load_index_from_storage,
)
from llama_index.readers.web import DEFAULT_WEBSITE_EXTRACTOR

Expand Down Expand Up @@ -81,32 +86,41 @@ def get_and_query(
index: [GPTVectorStoreIndex, GPTTreeIndex] = index_storage[
user_id
].get_index_or_throw()
# If multistep is enabled, multistep contains the llm_predictor.

if isinstance(index, GPTTreeIndex):
step_decompose_transform = StepDecomposeQueryTransform(multistep, verbose=True)
response = index.query(
query,
retriever = TreeSelectLeafRetriever(
index=index,
child_branch_factor=child_branch_factor,
refine_template=CHAT_REFINE_PROMPT,
use_async=True,
service_context=service_context,
optimizer=SentenceEmbeddingOptimizer(threshold_cutoff=0.7),
# Optionally have step_decompose_transform as query_transform if multistep is set
)
else:
step_decompose_transform = StepDecomposeQueryTransform(multistep, verbose=True)
if multistep:
index.index_struct.summary = "Provides information about everything you need to know about this topic, use this to answer the question."
response = index.query(
query,
response_mode=response_mode,
similarity_top_k=nodes,
refine_template=CHAT_REFINE_PROMPT,
use_async=True,
service_context=service_context,
optimizer=SentenceEmbeddingOptimizer(threshold_cutoff=0.7),
query_transform=step_decompose_transform if multistep else None,
retriever = VectorIndexRetriever(
index=index, similarity_top_k=nodes, service_context=service_context
)

response_synthesizer = ResponseSynthesizer.from_args(
response_mode=response_mode,
use_async=True,
refine_template=CHAT_REFINE_PROMPT,
optimizer=SentenceEmbeddingOptimizer(threshold_cutoff=0.7),
service_context=service_context,
)

query_engine = RetrieverQueryEngine(
retriever=retriever, response_synthesizer=response_synthesizer
)

multistep_query_engine = MultiStepQueryEngine(
query_engine=query_engine,
query_transform=StepDecomposeQueryTransform(multistep),
index_summary="Provides information about everything you need to know about this topic, use this to answer the question.",
)

if multistep:
response = multistep_query_engine.query(query)
else:
response = query_engine.query(query)

return response


Expand Down Expand Up @@ -154,13 +168,16 @@ def add_index(self, index, user_id, file_name):
parents=True, exist_ok=True
)
# Save the index to file under the user id
file = f"{file_name}_{date.today().month}_{date.today().day}"
file = f"{date.today().month}_{date.today().day}_{file_name}"
# If file is > 93 in length, cut it off to 93
if len(file) > 93:
file = file[:93]

index.save_to_disk(
EnvService.save_path() / "indexes" / f"{str(user_id)}" / f"{file}.json"
index.storage_context.persist(
persist_dir=EnvService.save_path()
/ "indexes"
/ f"{str(user_id)}"
/ f"{file}"
)

def reset_indexes(self, user_id):
Expand Down Expand Up @@ -210,8 +227,6 @@ async def rename_index(self, ctx, original_path, rename_path):

# Rename the file at f"indexes/{ctx.user.id}/{user_index}" to f"indexes/{ctx.user.id}/{new_name}" using Pathlib
try:
if not rename_path.endswith(".json"):
rename_path = rename_path + ".json"
Path(original_path).rename(rename_path)
return True
except Exception as e:
Expand All @@ -227,10 +242,10 @@ async def execute_index_chat_message(self, ctx, message):

if message.lower() in ["stop", "end", "quit", "exit"]:
await ctx.reply("Ending chat session.")
self.index_chat_chains.pop(message.channel.id)
self.index_chat_chains.pop(ctx.channel.id)

# close the thread
thread = await self.bot.fetch_channel(message.channel.id)
thread = await self.bot.fetch_channel(ctx.channel.id)
await thread.edit(name="Closed-GPT")
await thread.edit(archived=True)
return "Ended chat session."
Expand Down Expand Up @@ -259,14 +274,18 @@ async def start_index_chat(self, ctx, search, user, model):
)

summary_response = await self.loop.run_in_executor(
None, partial(index.query, "What is a summary of this document?")
None,
partial(
index.as_query_engine().query, "What is a summary of this document?"
),
)

query_engine = index.as_query_engine(similarity_top_k=3)

tool_config = IndexToolConfig(
index=index,
query_engine=query_engine,
name=f"Vector Index",
description=f"useful for when you want to answer queries about the external data you're connected to. The data you're connected to is: {summary_response}",
index_query_kwargs={"similarity_top_k": 3},
tool_kwargs={"return_direct": True},
)
toolkit = LlamaToolkit(
Expand Down Expand Up @@ -392,10 +411,8 @@ def index_github_repository(self, link, embed_model):
return index

def index_load_file(self, file_path) -> [GPTVectorStoreIndex, ComposableGraph]:
try:
index = GPTTreeIndex.load_from_disk(file_path)
except AssertionError:
index = GPTVectorStoreIndex.load_from_disk(file_path)
storage_context = StorageContext.from_defaults(persist_dir=file_path)
index = load_index_from_storage(storage_context)
return index

def index_discord(self, document, embed_model) -> GPTVectorStoreIndex:
Expand Down Expand Up @@ -929,13 +946,11 @@ async def compose_indexes(self, user_id, indexes, name, deep_compose):

# Now we have a list of tree indexes, we can compose them
if not name:
name = (
f"composed_deep_index_{date.today().month}_{date.today().day}.json"
)
name = f"{date.today().month}_{date.today().day}_composed_deep_index"

# Save the composed index
tree_index.save_to_disk(
EnvService.save_path() / "indexes" / str(user_id) / name
tree_index.storage_context.persist(
persist_dir=EnvService.save_path() / "indexes" / str(user_id) / name
)

self.index_storage[user_id].queryable_index = tree_index
Expand Down Expand Up @@ -965,11 +980,11 @@ async def compose_indexes(self, user_id, indexes, name, deep_compose):
)

if not name:
name = f"composed_index_{date.today().month}_{date.today().day}.json"
name = f"{date.today().month}_{date.today().day}_composed_index"

# Save the composed index
simple_index.save_to_disk(
EnvService.save_path() / "indexes" / str(user_id) / name
simple_index.storage_context.persist(
persist_dir=EnvService.save_path() / "indexes" / str(user_id) / name
)
self.index_storage[user_id].queryable_index = simple_index

Expand Down Expand Up @@ -1014,11 +1029,11 @@ async def backup_discord(
Path(EnvService.save_path() / "indexes" / str(ctx.guild.id)).mkdir(
parents=True, exist_ok=True
)
index.save_to_disk(
EnvService.save_path()
index.storage_context.persist(
persist_dir=EnvService.save_path()
/ "indexes"
/ str(ctx.guild.id)
/ f"{ctx.guild.name.replace(' ', '-')}_{date.today().month}_{date.today().day}.json"
/ f"{ctx.guild.name.replace(' ', '-')}_{date.today().month}_{date.today().day}"
)

await ctx.respond(embed=EmbedStatics.get_index_set_success_embed(price))
Expand Down Expand Up @@ -1113,8 +1128,7 @@ async def query(
await ctx_response.edit(
embed=EmbedStatics.get_index_query_failure_embed(
"Failed to send query. You may not have an index set, load an index with /index load"
),
delete_after=10,
)
)

# Extracted functions from DiscordReader
Expand Down
Loading

0 comments on commit 29cdd38

Please sign in to comment.