Streaming choice feature #2070

Open · wants to merge 9 commits into base: main
private_gpt/ui/ui.py: 82 changes (65 additions & 17 deletions)
@@ -51,6 +51,14 @@ class Modes(str, Enum):
]


class Styles(str, Enum):
STREAMING = "Streaming"
NON_STREAMING = "Non-Streaming"


STYLES: list[Styles] = [Styles.STREAMING, Styles.NON_STREAMING]
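# Note: because Styles subclasses str, the plain string value handed back by the
# Gradio dropdown (e.g. "Streaming") compares equal to the matching enum member,
# which is what the match statements below rely on.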


class Source(BaseModel):
file: str
page: str
@@ -105,6 +113,9 @@ def __init__(
)
self._system_prompt = self._get_default_system_prompt(self._default_mode)

# Initialize default response style: Streaming
self.response_style = STYLES[0]

def _chat(
self, message: str, history: list[list[str]], mode: Modes, *_: Any
) -> Any:
@@ -185,18 +196,30 @@ def build_history() -> list[ChatMessage]:
docs_ids.append(ingested_document.doc_id)
context_filter = ContextFilter(docs_ids=docs_ids)

query_stream = self._chat_service.stream_chat(
messages=all_messages,
use_context=True,
context_filter=context_filter,
)
yield from yield_deltas(query_stream)
match self.response_style:
case Styles.STREAMING:
# Stream the answer token by token, restricted to the selected documents.
query_stream = self._chat_service.stream_chat(
messages=all_messages,
use_context=True,
context_filter=context_filter,
)
yield from yield_deltas(query_stream)
case Styles.NON_STREAMING:
# Emit the complete answer in a single update instead of token deltas.
query_response = self._chat_service.chat(
messages=all_messages,
use_context=True,
context_filter=context_filter,
).response
yield from [query_response]

case Modes.BASIC_CHAT_MODE:
llm_stream = self._chat_service.stream_chat(
messages=all_messages,
use_context=False,
)
yield from yield_deltas(llm_stream)
match self.response_style:
case Styles.STREAMING:
llm_stream = self._chat_service.stream_chat(
all_messages, use_context=False
)
yield from yield_deltas(llm_stream)
case Styles.NON_STREAMING:
llm_response = self._chat_service.chat(
all_messages, use_context=False
).response
yield from [llm_response]

case Modes.SEARCH_MODE:
response = self._chunks_service.retrieve_relevant(
@@ -224,12 +247,21 @@ def build_history() -> list[ChatMessage]:
docs_ids.append(ingested_document.doc_id)
context_filter = ContextFilter(docs_ids=docs_ids)

summary_stream = self._summarize_service.stream_summarize(
use_context=True,
context_filter=context_filter,
instructions=message,
)
yield from yield_tokens(summary_stream)
match self.response_style:
case Styles.STREAMING:
summary_stream = self._summarize_service.stream_summarize(
use_context=True,
context_filter=context_filter,
instructions=message,
)
yield from yield_tokens(summary_stream)
case Styles.NON_STREAMING:
summary_response = self._summarize_service.summarize(
use_context=True,
context_filter=context_filter,
instructions=message,
)
# summarize() returns the completed summary, so emit it as a single
# message rather than iterating over it character by character.
yield from [summary_response]

# On initialization and on mode change, this function sets the system prompt
# to the default prompt based on the mode (and user settings).
@@ -282,6 +314,9 @@ def _set_current_mode(self, mode: Modes) -> Any:
gr.update(value=self._explanation_mode),
]

def _set_current_response_style(self, response_style: Styles) -> Any:
self.response_style = response_style

def _list_ingested_files(self) -> list[list[str]]:
files = set()
for ingested_document in self._ingest_service.list_ingested():
@@ -405,6 +440,15 @@ def _build_ui_blocks(self) -> gr.Blocks:
max_lines=3,
interactive=False,
)
default_response_style = STYLES[0]
response_style = gr.Dropdown(
[response_style.value for response_style in STYLES],
label="Response Style",
value=default_response_style,
interactive=True,
)
upload_button = gr.components.UploadButton(
"Upload File(s)",
type="filepath",
@@ -498,6 +542,10 @@ def _build_ui_blocks(self) -> gr.Blocks:
self._set_system_prompt,
inputs=system_prompt_input,
)
# When response style changes
response_style.change(
self._set_current_response_style, inputs=response_style
)

def get_model_label() -> str | None:
"""Get model label from llm mode setting YAML.
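For context, the practical difference between the two response styles is whether the chat callback yields progressively accumulated text or a single complete answer. The sketch below (not taken from the diff) illustrates that dispatch pattern; fake_token_stream is a hypothetical stand-in for the chat and summarize services, and yield_deltas is a simplified version of the UI helper, assuming it accumulates deltas into the message shown so far.

from collections.abc import Iterable, Iterator
from enum import Enum


class Styles(str, Enum):
    STREAMING = "Streaming"
    NON_STREAMING = "Non-Streaming"


def fake_token_stream() -> Iterator[str]:
    # Hypothetical stand-in for the chat / summarize services.
    yield from ["Priv", "ate", "GPT"]


def yield_deltas(stream: Iterable[str]) -> Iterator[str]:
    # Simplified UI helper: each yield is the message so far, so the chat box
    # grows token by token while streaming.
    full_response = ""
    for delta in stream:
        full_response += delta
        yield full_response


def respond(response_style: str) -> Iterator[str]:
    # The dropdown hands back a plain string; the value patterns still match
    # because Styles subclasses str.
    match response_style:
        case Styles.STREAMING:
            yield from yield_deltas(fake_token_stream())
        case Styles.NON_STREAMING:
            yield from ["".join(fake_token_stream())]


print(list(respond("Streaming")))      # ['Priv', 'Private', 'PrivateGPT']
print(list(respond("Non-Streaming")))  # ['PrivateGPT']

Keeping both branches as generators means the chat handler can stay a single generator function regardless of the selected style; the non-streaming case simply produces one chunk.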