diff --git a/private_gpt/ui/ui.py b/private_gpt/ui/ui.py
index 04258ac39..7f96a4f2e 100644
--- a/private_gpt/ui/ui.py
+++ b/private_gpt/ui/ui.py
@@ -51,6 +51,14 @@ class Modes(str, Enum):
 ]
 
 
+class Styles(str, Enum):
+    STREAMING = "Streaming"
+    NON_STREAMING = "Non-Streaming"
+
+
+STYLES: list[Styles] = [Styles.STREAMING, Styles.NON_STREAMING]
+
+
 class Source(BaseModel):
     file: str
     page: str
@@ -105,6 +113,9 @@ def __init__(
         )
         self._system_prompt = self._get_default_system_prompt(self._default_mode)
 
+        # Initialize default response style: Streaming
+        self.response_style = STYLES[0]
+
     def _chat(
         self, message: str, history: list[list[str]], mode: Modes, *_: Any
     ) -> Any:
@@ -185,18 +196,30 @@ def build_history() -> list[ChatMessage]:
                     docs_ids.append(ingested_document.doc_id)
                 context_filter = ContextFilter(docs_ids=docs_ids)
 
-                query_stream = self._chat_service.stream_chat(
-                    messages=all_messages,
-                    use_context=True,
-                    context_filter=context_filter,
-                )
-                yield from yield_deltas(query_stream)
+                match self.response_style:
+                    case Styles.STREAMING:
+                        query_stream = self._chat_service.stream_chat(
+                            all_messages, use_context=False
+                        )
+                        yield from yield_deltas(query_stream)
+                    case Styles.NON_STREAMING:
+                        query_response = self._chat_service.chat(
+                            all_messages, use_context=False
+                        ).response
+                        yield from [query_response]
+
             case Modes.BASIC_CHAT_MODE:
-                llm_stream = self._chat_service.stream_chat(
-                    messages=all_messages,
-                    use_context=False,
-                )
-                yield from yield_deltas(llm_stream)
+                match self.response_style:
+                    case Styles.STREAMING:
+                        llm_stream = self._chat_service.stream_chat(
+                            all_messages, use_context=False
+                        )
+                        yield from yield_deltas(llm_stream)
+                    case Styles.NON_STREAMING:
+                        llm_response = self._chat_service.chat(
+                            all_messages, use_context=False
+                        ).response
+                        yield from [llm_response]
 
             case Modes.SEARCH_MODE:
                 response = self._chunks_service.retrieve_relevant(
@@ -224,12 +247,21 @@ def build_history() -> list[ChatMessage]:
                     docs_ids.append(ingested_document.doc_id)
                 context_filter = ContextFilter(docs_ids=docs_ids)
 
-                summary_stream = self._summarize_service.stream_summarize(
-                    use_context=True,
-                    context_filter=context_filter,
-                    instructions=message,
-                )
-                yield from yield_tokens(summary_stream)
+                match self.response_style:
+                    case Styles.STREAMING:
+                        summary_stream = self._summarize_service.stream_summarize(
+                            use_context=True,
+                            context_filter=context_filter,
+                            instructions=message,
+                        )
+                        yield from yield_tokens(summary_stream)
+                    case Styles.NON_STREAMING:
+                        summary_response = self._summarize_service.summarize(
+                            use_context=True,
+                            context_filter=context_filter,
+                            instructions=message,
+                        )
+                        yield from summary_response
 
     # On initialization and on mode change, this function set the system prompt
     # to the default prompt based on the mode (and user settings).
@@ -282,6 +314,9 @@ def _set_current_mode(self, mode: Modes) -> Any:
             gr.update(value=self._explanation_mode),
         ]
 
+    def _set_current_response_style(self, response_style: Styles) -> Any:
+        self.response_style = response_style
+
     def _list_ingested_files(self) -> list[list[str]]:
         files = set()
         for ingested_document in self._ingest_service.list_ingested():
@@ -405,6 +440,15 @@ def _build_ui_blocks(self) -> gr.Blocks:
                         max_lines=3,
                         interactive=False,
                     )
+                    default_response_style = STYLES[0]
+                    response_style = (
+                        gr.Dropdown(
+                            [response_style.value for response_style in STYLES],
+                            label="Response Style",
+                            value=default_response_style,
+                            interactive=True,
+                        ),
+                    )
                     upload_button = gr.components.UploadButton(
                         "Upload File(s)",
                         type="filepath",
@@ -498,6 +542,10 @@ def _build_ui_blocks(self) -> gr.Blocks:
                         self._set_system_prompt,
                         inputs=system_prompt_input,
                     )
+                    # When response style changes
+                    response_style[0].change(
+                        self._set_current_response_style, inputs=response_style
+                    )
 
                 def get_model_label() -> str | None:
                     """Get model label from llm mode setting YAML.
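
To make the new dispatch easier to reason about in review, here is a minimal, self-contained sketch of the same pattern outside Gradio: a str-backed enum selects between a streaming generator path and a single blocking response. FakeChatService, Completion and respond are hypothetical stand-ins invented for this sketch, not the real ChatService API; they only mirror the call shapes (stream_chat yielding deltas, chat(...).response returning the full text) that the diff relies on.

# Minimal sketch only: FakeChatService, Completion and respond() are invented
# for illustration and are NOT the actual ChatService API touched by the diff.
from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum


class Styles(str, Enum):
    STREAMING = "Streaming"
    NON_STREAMING = "Non-Streaming"


@dataclass
class Completion:
    response: str


class FakeChatService:
    """Stand-in backend that only mimics the two call shapes the UI relies on."""

    def stream_chat(self, messages: list[str]) -> Iterator[str]:
        # Streaming path: hand back the reply one token at a time.
        for token in "Hello from the fake backend".split():
            yield token + " "

    def chat(self, messages: list[str]) -> Completion:
        # Blocking path: hand back the whole reply in a single object.
        return Completion(response="Hello from the fake backend")


def respond(service: FakeChatService, messages: list[str], style: Styles) -> Iterator[str]:
    """Yield growing partial text for STREAMING, or the full reply once for NON_STREAMING."""
    match style:
        case Styles.STREAMING:
            partial = ""
            for delta in service.stream_chat(messages):
                partial += delta
                yield partial  # the visible answer grows as deltas arrive
        case Styles.NON_STREAMING:
            yield service.chat(messages).response  # one final answer, no deltas


if __name__ == "__main__":
    # Styles is a str Enum, so the raw dropdown string compares equal to the member.
    assert Styles.STREAMING == "Streaming"
    for text in respond(FakeChatService(), ["hi"], Styles.NON_STREAMING):
        print(text)

The accumulation in the STREAMING branch is roughly what yield_deltas does in the UI, so switching the dropdown between the two styles only changes whether the chat box fills in gradually or appears all at once.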