diff --git a/docs/help.md b/docs/help.md index 6813ef7af..2ca080b5a 100644 --- a/docs/help.md +++ b/docs/help.md @@ -99,6 +99,7 @@ Options: --cid, --conversation TEXT Continue the conversation with the given ID. --key TEXT API key to use --save TEXT Save prompt with this template name + -r, --rich Format output as rich markdown text --help Show this message and exit. ``` @@ -119,6 +120,7 @@ Options: -o, --option ... key/value options for the model --no-stream Do not stream output --key TEXT API key to use + -r, --rich Format output as rich markdown text --help Show this message and exit. ``` diff --git a/llm/cli.py b/llm/cli.py index 8e649a6d9..06a2cd3fe 100644 --- a/llm/cli.py +++ b/llm/cli.py @@ -28,6 +28,9 @@ import base64 import pathlib import pydantic +from rich.console import Console +from rich.live import Live +from rich.markdown import Markdown from runpy import run_module import shutil import sqlite_utils @@ -45,6 +48,8 @@ DEFAULT_TEMPLATE = "prompt: " +console = Console() + def _validate_metadata_json(ctx, param, value): if value is None: @@ -121,6 +126,13 @@ def cli(): ) @click.option("--key", help="API key to use") @click.option("--save", help="Save prompt with this template name") +@click.option( + "--rich", + "-r", + is_flag=True, + default=False, + help="Format output as rich markdown text", +) def prompt( prompt, system, @@ -135,6 +147,7 @@ def prompt( conversation_id, key, save, + rich, ): """ Execute a prompt @@ -272,13 +285,7 @@ def read_prompt(): try: response = prompt_method(prompt, system, **validated_options) - if should_stream: - for chunk in response: - print(chunk, end="") - sys.stdout.flush() - print("") - else: - print(response.text()) + print_response(response=response, stream=should_stream, rich=rich) except Exception as ex: raise click.ClickException(str(ex)) @@ -326,6 +333,13 @@ def read_prompt(): ) @click.option("--no-stream", is_flag=True, help="Do not stream output") @click.option("--key", help="API key to use") +@click.option( + "--rich", + "-r", + is_flag=True, + default=False, + help="Format output as rich markdown text", +) def chat( system, model_id, @@ -336,6 +350,7 @@ def chat( options, no_stream, key, + rich, ): """ Hold an ongoing chat with a model. @@ -435,11 +450,9 @@ def chat( response = conversation.prompt(prompt, system, **validated_options) # System prompt only sent for the first message: system = None - for chunk in response: - print(chunk, end="") - sys.stdout.flush() + print_response(response=response, stream=True, rich=rich) response.log_to_db(db) - print("") + console.print("") def load_conversation(conversation_id: Optional[str]) -> Optional[Conversation]: @@ -1641,3 +1654,19 @@ def _human_readable_size(size_bytes): def logs_on(): return not (user_dir() / "logs-off").exists() + + +def print_response(response: Response, stream: bool = True, rich: bool = False): + if stream is True and rich is False: + for chunk in response: + console.print(chunk, end="") + elif stream is True and rich is True: + md = "" + with Live(Markdown(""), console=console) as live: + for chunk in response: + md += chunk + live.update(Markdown(md)) + elif stream is False and rich is True: + console.print(Markdown(response.text())) + else: + console.print(response.text()) diff --git a/setup.py b/setup.py index 8025e68e0..168ee8976 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,7 @@ def get_long_description(): "python-ulid", "setuptools", "pip", + "rich", ], extras_require={ "test": [