Faster local indexing and fix rag chat output
AmanKishore committed Oct 23, 2023
1 parent 5239de7 commit 1db0653
Showing 3 changed files with 47 additions and 29 deletions.
2 changes: 1 addition & 1 deletion mirageml/commands/chat.py
@@ -51,7 +51,7 @@ def chat(files: list[str] = [], urls: list[str] = [], sources: list[str] = []):
     chat_history = [{"role": "system", "content": "You are a helpful assistant."}]
     ai_response = ""
     if sources:
-        chat_history = rag_chat(sources)
+        chat_history, ai_response = rag_chat(sources)
 
     while True:
         # Loop for follow-up questions
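
The one-line change above alters rag_chat's return contract: it now hands back the assistant's final reply alongside the chat history, so chat() no longer starts its follow-up loop with an empty ai_response. A runnable sketch of the new contract, with a hypothetical stub standing in for the real rag_chat:

```python
# Sketch of the updated contract between chat() and rag_chat(). The stub
# below is hypothetical, for illustration only; names match this diff.
def rag_chat(sources):
    history = [{"role": "system", "content": "You are a helpful assistant."}]
    answer = "stub answer built from " + ", ".join(sources)
    history.append({"role": "assistant", "content": answer})
    return history, answer  # new: (history, last reply) instead of history alone

chat_history = [{"role": "system", "content": "You are a helpful assistant."}]
ai_response = ""
sources = ["docs"]
if sources:
    # was: chat_history = rag_chat(sources), leaving ai_response as ""
    chat_history, ai_response = rag_chat(sources)
print(ai_response)
```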
16 changes: 13 additions & 3 deletions mirageml/commands/rag.py
@@ -15,6 +15,7 @@
     qdrant_search,
     remote_qdrant_search,
 )
+from .utils.codeblocks import add_indices_to_code_blocks
 
 console = Console()
 config = load_config()
@@ -99,9 +100,9 @@ def rag_chat(sources):
             box=HORIZONTALS,
         ),
         console=console,
-        screen=False,
-        auto_refresh=True,
+        transient=True,
+        vertical_overflow="visible",
         refresh_per_second=8,
     ) as live:
         sorted_hits = search_and_rank(live, user_input, sources)
         sources_used = list(set([hit["payload"]["source"] for hit in sorted_hits]))
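
Aside: the two Live options swapped in above change how the progress panel behaves. With transient=True, Rich erases the live display when the context manager exits (clearing the way for the final Assistant panel printed below), and vertical_overflow="visible" renders content taller than the terminal instead of cropping it. A small self-contained illustration, assuming nothing beyond the Rich API:

```python
# transient=True erases the live display on exit; vertical_overflow="visible"
# prints content taller than the screen rather than cropping it.
import time
from rich.live import Live
from rich.panel import Panel

with Live(Panel("Searching sources..."), transient=True,
          vertical_overflow="visible", refresh_per_second=8) as live:
    time.sleep(0.5)
    live.update(Panel("Generating response..."))
    time.sleep(0.5)
# the progress panel is gone now, so a persistent final answer can be printed
print("final answer")
```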
@@ -161,4 +162,13 @@ def rag_chat(sources):
             )
         )
         chat_history.append({"role": "assistant", "content": ai_response})
-        return chat_history
+        indexed_ai_response = add_indices_to_code_blocks(ai_response)
+        console.print(
+            Panel(
+                Markdown(indexed_ai_response),
+                title="[bold blue]Assistant[/bold blue]",
+                box=HORIZONTALS,
+                border_style="blue",
+            )
+        )
+        return chat_history, ai_response
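
The new import and final console.print above indicate that the assistant's reply is post-processed so its code blocks carry indices before being rendered. The helper's implementation is not part of this commit; a hypothetical sketch of what add_indices_to_code_blocks could look like (the real version in mirageml/commands/utils/codeblocks.py may differ):

```python
# Hypothetical sketch only: number each fenced code block in a markdown
# string so the user can refer to "block [1]", "block [2]", ... in follow-ups.
import re

FENCE = "`" * 3  # avoids writing a literal code fence inside this example

def add_indices_to_code_blocks(text: str) -> str:
    counter = 0

    def label(match: re.Match) -> str:
        nonlocal counter
        counter += 1
        return f"[{counter}]\n{match.group(0)}"

    pattern = re.escape(FENCE) + r".*?" + re.escape(FENCE)
    return re.sub(pattern, label, text, flags=re.DOTALL)

sample = f"intro\n{FENCE}python\nprint('hi')\n{FENCE}\n"
print(add_indices_to_code_blocks(sample))  # the block is now prefixed "[1]"
```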
58 changes: 33 additions & 25 deletions mirageml/commands/utils/vectordb.py
@@ -7,6 +7,8 @@
 import requests
 import tiktoken
 import typer
+from concurrent.futures import ThreadPoolExecutor
+
 from qdrant_client import QdrantClient
 from qdrant_client.http.models import Distance, PointStruct, VectorParams
 from rich.console import Console
@@ -54,6 +56,34 @@ def create_remote_qdrant_db(collection_name, link=None, path=None):
     if path:
         data, metadata = crawl_files(path)
 
+    def make_request(args):
+        i, curr_data, metadata_value, live = args
+        json_data = {
+            "user_id": user_id,
+            "collection_name": collection_name,
+            "data": [curr_data],
+            "metadata": [metadata_value],
+        }
+        if i == 0:
+            response = requests.post(
+                VECTORDB_CREATE_ENDPOINT, json=json_data, headers=get_headers(), stream=True
+            )
+        else:
+            response = requests.post(
+                VECTORDB_UPSERT_ENDPOINT, json=json_data, headers=get_headers(), stream=True
+            )
+
+        if response.status_code == 200:
+            for chunk in response.iter_lines():
+                # process line here
+                live.update(
+                    Panel(
+                        f"Indexing: {chunk.decode()}",
+                        title="[bold green]Indexer[/bold green]",
+                        border_style="green",
+                    )
+                )
+
     console = Console()
     with Live(
         Panel(
@@ -67,31 +97,9 @@ def create_remote_qdrant_db(collection_name, link=None, path=None):
         vertical_overflow="visible",
     ) as live:
         if data:
-            for i, curr_data in enumerate(data):
-                json_data = {
-                    "user_id": user_id,
-                    "collection_name": collection_name,
-                    "data": [curr_data],
-                    "metadata": [metadata[i]],
-                }
-                if i == 0:
-                    response = requests.post(
-                        VECTORDB_CREATE_ENDPOINT, json=json_data, headers=get_headers(), stream=True
-                    )
-                else:
-                    response = requests.post(
-                        VECTORDB_UPSERT_ENDPOINT, json=json_data, headers=get_headers(), stream=True
-                    )
-                if response.status_code == 200:
-                    for chunk in response.iter_lines():
-                        # process line here
-                        live.update(
-                            Panel(
-                                f"Indexing: {chunk.decode()}",
-                                title="[bold green]Indexer[/bold green]",
-                                border_style="green",
-                            )
-                        )
+            with ThreadPoolExecutor(max_workers=10) as executor:
+                args_list = [(i, curr_data, metadata[i], live) for i, curr_data in enumerate(data)]
+                list(executor.map(make_request, args_list))
         else:
             json_data = {
                 "user_id": user_id,
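
The vectordb.py refactor is where the "faster local indexing" in the commit title comes from: the per-chunk HTTP round-trips, previously issued one after another in a loop, are moved into make_request and fanned out over a thread pool so they overlap. A minimal self-contained sketch of the same executor.map pattern, with a stub in place of the real request:

```python
# Fan-out pattern as in the diff: one worker call per item, results drained
# eagerly. list() forces the map iterator, so any exception raised inside a
# worker is re-raised here rather than being silently dropped.
import time
from concurrent.futures import ThreadPoolExecutor

def make_request(args):
    i, item = args
    time.sleep(0.1)  # stand-in for the network round-trip
    return f"indexed {item} (#{i})"

items = ["chunk-a", "chunk-b", "chunk-c"]
with ThreadPoolExecutor(max_workers=10) as executor:
    args_list = [(i, item) for i, item in enumerate(items)]
    results = list(executor.map(make_request, args_list))
print(results)
```

One property of the fan-out worth noting: executor.map preserves the order of results, not of execution, so the i == 0 create request is no longer guaranteed to complete before the upserts begin.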
