diff --git a/src/llama_stack_client/lib/cli/inference/inference.py b/src/llama_stack_client/lib/cli/inference/inference.py
index 7c4562a4..7280ceff 100644
--- a/src/llama_stack_client/lib/cli/inference/inference.py
+++ b/src/llama_stack_client/lib/cli/inference/inference.py
@@ -4,7 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from typing import Optional
+from typing import Optional, List, Dict
+import traceback
 
 import click
 from rich.console import Console
@@ -19,13 +20,20 @@ def inference():
 
 
 @click.command("chat-completion")
-@click.option("--message", required=True, help="Message")
+@click.option("--message", help="Message")
 @click.option("--stream", is_flag=True, help="Streaming", default=False)
+@click.option("--session", is_flag=True, help="Start a Chat Session", default=False)
 @click.option("--model-id", required=False, help="Model ID")
 @click.pass_context
 @handle_client_errors("inference chat-completion")
-def chat_completion(ctx, message: str, stream: bool, model_id: Optional[str]):
+def chat_completion(ctx, message: str, stream: bool, session: bool, model_id: Optional[str]):
     """Show available inference chat completion endpoints on distribution endpoint"""
+    if not message and not session:
+        click.secho(
+            "you must specify either --message or --session",
+            fg="red",
+        )
+        raise click.exceptions.Exit(1)
     client = ctx.obj["client"]
     console = Console()
 
@@ -33,16 +41,46 @@ def chat_completion(ctx, message: str, stream: bool, model_id: Optional[str]):
         available_models = [model.identifier for model in client.models.list() if model.model_type == "llm"]
         model_id = available_models[0]
 
-    response = client.inference.chat_completion(
-        model_id=model_id,
-        messages=[{"role": "user", "content": message}],
-        stream=stream,
-    )
-    if not stream:
-        console.print(response)
-    else:
-        for event in EventLogger().log(response):
-            event.print()
+    messages = []
+    if message:
+        messages.append({"role": "user", "content": message})
+        response = client.inference.chat_completion(
+            model_id=model_id,
+            messages=messages,
+            stream=stream,
+        )
+        if not stream:
+            console.print(response)
+        else:
+            for event in EventLogger().log(response):
+                event.print()
+    if session:
+        chat_session(client=client, model_id=model_id, messages=messages, console=console)
+
+
+def chat_session(client, model_id: Optional[str], messages: List[Dict[str, str]], console: Console):
+    """Run an interactive chat session with the served model"""
+    while True:
+        try:
+            message = input(">>> ")
+            if message in ["\\q", "quit"]:
+                console.print("Exiting")
+                break
+            messages.append({"role": "user", "content": message})
+            response = client.inference.chat_completion(
+                model_id=model_id,
+                messages=messages,
+                stream=True,
+            )
+            for event in EventLogger().log(response):
+                event.print()
+        except Exception as exc:
+            traceback.print_exc()
+            console.print(f"Error in chat session {exc}")
+            break
+        except KeyboardInterrupt as exc:
+            console.print("\nDetected user interrupt, exiting")
+            break
 
 
 # Register subcommands
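
The new --session flag turns chat-completion into an interactive loop: each prompt read from stdin is appended to the shared messages list before the next client.inference.chat_completion call, so every turn resends the accumulated history. As a rough illustration of that accumulation (not part of the PR), here is a minimal sketch where a hypothetical stub stands in for the real client and for EventLogger:

# Illustrative sketch only: StubInference/StubClient are made up for this example
# and mimic the streaming call shape used in chat_session above.
from typing import Dict, Iterator, List


class StubInference:
    def chat_completion(self, model_id: str, messages: List[Dict[str, str]], stream: bool) -> Iterator[str]:
        # Echo the latest user message back as a single streamed "chunk".
        yield f"[{model_id}] you said: {messages[-1]['content']}"


class StubClient:
    inference = StubInference()


client = StubClient()
messages: List[Dict[str, str]] = []
for turn in ["hello", "and a follow-up question"]:
    # Mirrors the loop body: append the user turn, then stream a reply over the full history.
    messages.append({"role": "user", "content": turn})
    for chunk in client.inference.chat_completion(model_id="demo-model", messages=messages, stream=True):
        print(chunk)

# Two user turns retained; note the PR's loop appends only user messages,
# not the assistant replies, back into the history.
print(len(messages))

In use, the session would presumably be started with something like `llama-stack-client inference chat-completion --session` and exited with \q, quit, or Ctrl-C, per the loop in the diff.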