Merge branch 'main' into azure-fileshare-downloader
rchan26 committed Jun 12, 2024
2 parents da0c93a + bb0a9db commit b0d296c
Showing 7 changed files with 128 additions and 23 deletions.
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default.

6 changes: 5 additions & 1 deletion pyproject.toml
@@ -57,6 +57,7 @@ langchain-community = "^0.2.4"
 tiktoken = "^0.7.0"
 llama-index-embeddings-huggingface = "^0.2.1"
 azure-storage-file-share = "^12.16.0"
+rich = "^13.7.1"


 [tool.poetry.group.dev.dependencies]
@@ -106,10 +107,13 @@ build-backend = "poetry.core.masonry.api"
 minversion = "6.0"
 testpaths = [
     "tests",
+    "reginald",
 ]
 addopts = """
---cov=estios
+--cov=reginald
 --cov-report=term:skip-covered
+--cov-append
 --pdbcls=IPython.terminal.debugger:TerminalPdb
+--doctest-modules
 """
 doctest_optionflags = ["NORMALIZE_WHITESPACE", "ELLIPSIS",]
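
With --doctest-modules now in addopts, pytest also collects the Examples blocks from docstrings under reginald (such as those added to reginald/utils.py below) and runs them as tests, folding the results into the --cov=reginald coverage report. A minimal sketch of the docstring pattern this enables; nap_report is an illustrative stand-in, not a function in the repo:

def nap_report(naps: int = 2) -> str:
    """Summarise naps taken.

    Examples
    --------
    >>> nap_report(2)
    'naps taken: 2'
    """
    return f"naps taken: {naps}"

Running pytest executes that example alongside the suites in tests/, and --cov-append merges everything into one combined report.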
13 changes: 6 additions & 7 deletions reginald/models/models/chat_completion.py
@@ -6,7 +6,7 @@
 from openai import AzureOpenAI, OpenAI

 from reginald.models.models.base import MessageResponse, ResponseModel
-from reginald.utils import get_env_var
+from reginald.utils import get_env_var, stream_iter_progress_wrapper


 class ChatCompletionBase(ResponseModel):
@@ -180,9 +180,8 @@ def stream_message(self, message: str, user_id: str) -> None:
             stream=True,
         )

-        print("Reginald: ", end="")
-        for chunk in response:
-            print(chunk.choices[0].delta.content)
+        for chunk in stream_iter_progress_wrapper(response):
+            print(chunk.choices[0].delta.content, end="", flush=True)


 class ChatCompletionOpenAI(ChatCompletionBase):
@@ -269,6 +268,6 @@ def stream_message(self, message: str, user_id: str) -> None:
             messages=[{"role": "user", "content": message}],
             stream=True,
         )
-        print("Reginald: ", end="")
-        for chunk in response:
-            print(chunk["choices"][0]["delta"]["content"])
+
+        for chunk in stream_iter_progress_wrapper(response):
+            print(chunk.choices[0].delta.content, end="", flush=True)
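
Both OpenAI backends now push their chunk streams through stream_iter_progress_wrapper, which keeps a transient spinner up until the first chunk arrives and then prints the "Reginald: " prefix before handing chunks through unchanged. A minimal sketch of the consumption pattern, with a hypothetical fake_stream generator standing in for a real API response:

from time import sleep
from typing import Generator

from reginald.utils import stream_iter_progress_wrapper


def fake_stream() -> Generator[str, None, None]:
    # stand-in for a chunked completion response; not a real API call
    for token in ("Hello", ", ", "world", "!"):
        sleep(0.2)  # simulate per-chunk network latency
        yield token


# the spinner runs until the first token is ready, then "Reginald: " prints
for token in stream_iter_progress_wrapper(fake_stream):
    print(token, end="", flush=True)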
8 changes: 5 additions & 3 deletions reginald/models/models/hello.py
@@ -1,6 +1,7 @@
 from typing import Any

 from reginald.models.models.base import MessageResponse, ResponseModel
+from reginald.utils import stream_iter_progress_wrapper


 class Hello(ResponseModel):
@@ -18,6 +19,7 @@ def channel_mention(self, message: str, user_id: str) -> MessageResponse:
         return MessageResponse(f"Hello <@{user_id}>")

     def stream_message(self, message: str, user_id: str) -> None:
-        print("\nReginald: ", end="")
-        for token in ["Hello", "!", " How", " are", " you", "?"]:
-            print(token, end="")
+        # print("\nReginald: ", end="")
+        token_list: tuple[str, ...] = ("Hello", "!", " How", " are", " you", "?")
+        for token in stream_iter_progress_wrapper(token_list):
+            print(token, end="", flush=True)
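
The canned response can stay a plain tuple because stream_iter_progress_wrapper converts a list or tuple into a generator before pulling the first item, so the hard-coded tokens stream exactly like a live model response.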
35 changes: 24 additions & 11 deletions reginald/models/models/llama_index.py
@@ -44,7 +44,11 @@

 from reginald.models.models.base import MessageResponse, ResponseModel
 from reginald.models.models.llama_utils import completion_to_prompt, messages_to_prompt
-from reginald.utils import get_env_var
+from reginald.utils import (
+    get_env_var,
+    stream_iter_progress_wrapper,
+    stream_progress_wrapper,
+)

 nest_asyncio.apply()

@@ -632,17 +636,27 @@ def __init__(
                 data_dir=self.data_dir,
                 settings=settings,
             )
-            self.index = data_creator.create_index()
-            data_creator.save_index()
+            self.index = stream_progress_wrapper(
+                data_creator.create_index,
+                task_str="Generating the index from scratch...",
+            )
+            stream_progress_wrapper(
+                data_creator.save_index,
+                task_str="Saving the index...",
+            )

         else:
             logging.info("Loading the storage context")
-            storage_context = StorageContext.from_defaults(
-                persist_dir=self.data_dir / LLAMA_INDEX_DIR / self.which_index
+            storage_context = stream_progress_wrapper(
+                StorageContext.from_defaults,
+                task_str="Loading the storage context...",
+                persist_dir=self.data_dir / LLAMA_INDEX_DIR / self.which_index,
             )

             logging.info("Loading the pre-processed index")
-            self.index = load_index_from_storage(
+            self.index = stream_progress_wrapper(
+                load_index_from_storage,
+                task_str="Loading the pre-processed index...",
                 storage_context=storage_context,
                 settings=settings,
             )
@@ -862,19 +876,18 @@ def stream_message(self, message: str, user_id: str) -> None:
             self.query_engine._response_synthesizer._streaming = True
             response_stream = self.query_engine.query(message)

-            print("\nReginald: ", end="")
-            for token in response_stream.response_gen:
-                print(token, end="")
+            for token in stream_iter_progress_wrapper(response_stream.response_gen):
+                print(token, end="", flush=True)

             formatted_response = "\n\n\n" + self._format_sources(response_stream)

             for token in re.split(r"(\s+)", formatted_response):
-                print(token, end="")
+                print(token, end="", flush=True)
         except Exception as e:  # ignore: broad-except
             for token in re.split(
                 r"(\s+)", self.error_response_template.format(repr(e))
             ):
-                print(token, end="")
+                print(token, end="", flush=True)


 class LlamaIndexOllama(LlamaIndex):
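
The non-iterable, blocking steps (building, saving, and loading the index) instead go through stream_progress_wrapper, which runs a callable to completion behind a transient spinner and forwards extra keyword arguments to it. A sketch of the calling convention, with a hypothetical slow_load standing in for the expensive step:

from time import sleep

from reginald.utils import stream_progress_wrapper


def slow_load(path: str) -> str:
    # stand-in for an expensive index-loading call
    sleep(2)
    return f"loaded {path}"


# the spinner spins while slow_load runs; path is forwarded as a kwarg
result = stream_progress_wrapper(
    slow_load,
    task_str="Loading the index...",
    path="data/index",
)
print(result)  # prints: loaded data/index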
2 changes: 2 additions & 0 deletions reginald/run.py
@@ -61,6 +61,8 @@ def run_chat_interact(streaming: bool = False, **kwargs) -> ResponseModel:
         message = input(">>> ")
         if message in ["exit", "exit()", "quit()", "bye Reginald"]:
             return response_model
+        if message == "":
+            continue
         if message in ["clear_history", "\clear_history"]:
             if (
                 response_model.mode == "chat"
85 changes: 85 additions & 0 deletions reginald/utils.py
@@ -1,5 +1,90 @@
 import logging
 import os
+from itertools import chain
+from time import sleep
+from typing import Any, Callable, Final, Generator, Iterable
+
+from rich.progress import Progress, SpinnerColumn, TextColumn
+
+REGINAL_PROMPT: Final[str] = "Reginald: "
+
+
+def stream_iter_progress_wrapper(
+    streamer: Iterable | Callable | chain,
+    task_str: str = REGINAL_PROMPT,
+    progress_bar: bool = True,
+    end: str = "",
+    *args,
+    **kwargs,
+) -> Iterable:
+    """Add a progress bar for iteration.
+    Examples
+    --------
+    >>> from time import sleep
+    >>> def sleeper(naps: int = 3) -> Generator[str, None, None]:
+    ...     for nap in range(naps):
+    ...         sleep(1)
+    ...         yield f'nap: {nap}'
+    >>> tuple(stream_iter_progress_wrapper(streamer=sleeper))
+    <BLANKLINE>
+    Reginald: ('nap: 0', 'nap: 1', 'nap: 2')
+    >>> tuple(stream_iter_progress_wrapper(
+    ...     streamer=sleeper, progress_bar=False))
+    Reginald: ('nap: 0', 'nap: 1', 'nap: 2')
+    """
+    if isinstance(streamer, Callable):
+        streamer = streamer(*args, **kwargs)
+    if progress_bar:
+        with Progress(
+            TextColumn("{task.description}[progress.description]"),
+            SpinnerColumn(),
+            transient=True,
+        ) as progress:
+            if isinstance(streamer, list | tuple):
+                streamer = (item for item in streamer)
+            assert isinstance(streamer, Generator)
+            progress.add_task(task_str)
+            first_item = next(streamer)
+            streamer = chain((first_item,), streamer)
+    print(task_str, end=end)
+    return streamer
+
+
+def stream_progress_wrapper(
+    streamer: Callable,
+    task_str: str = REGINAL_PROMPT,
+    progress_bar: bool = True,
+    end: str = "\n",
+    *args,
+    **kwargs,
+) -> Any:
"""Add a progress bar for iteration.
+    Examples
+    --------
+    >>> from time import sleep
+    >>> def sleeper(seconds: int = 3) -> str:
+    ...     sleep(seconds)
+    ...     return f'{seconds} seconds nap'
+    >>> stream_progress_wrapper(sleeper)
+    <BLANKLINE>
+    Reginald: 
+    '3 seconds nap'
+    """
+    if progress_bar:
+        with Progress(
+            TextColumn("{task.description}[progress.description]"),
+            SpinnerColumn(),
+            transient=True,
+        ) as progress:
+            progress.add_task(task_str)
+            results: Any = streamer(*args, **kwargs)
+        print(task_str, end=end)
+        return results
+    else:
+        print(task_str, end=end)
+        return streamer(*args, **kwargs)
 
 
 def get_env_var(
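
The two wrappers differ in when the prompt prints. stream_iter_progress_wrapper pulls a single item with next(streamer) while the Progress context is still open, so the spinner stays visible until the stream actually produces output, then chain((first_item,), streamer) stitches that first item back onto the front so callers still iterate over the full stream. stream_progress_wrapper simply runs the callable to completion under the spinner and returns its result.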
