Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nemoguardrails/actions/llm/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ async def init(self):
self._init_flows_index(),
)

def _extract_user_message_example(self, flow: Flow):
def _extract_user_message_example(self, flow: Flow) -> None:
"""Heuristic to extract user message examples from a flow."""
elements = [
item
Expand Down
52 changes: 35 additions & 17 deletions nemoguardrails/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

import logging
import os
from typing import List, Optional
from enum import Enum
from typing import Any, List, Literal, Optional

import typer
import uvicorn
Expand All @@ -27,13 +28,24 @@
from nemoguardrails.cli.chat import run_chat
from nemoguardrails.cli.migration import migrate
from nemoguardrails.cli.providers import _list_providers, select_provider_with_type
from nemoguardrails.eval import cli
from nemoguardrails.eval import cli as eval_cli
from nemoguardrails.logging.verbose import set_verbose
from nemoguardrails.utils import init_random_seed


class ColangVersions(str, Enum):
    """Colang versions that can be selected from the command line.

    The members mix in ``str`` so that each one compares equal to, and can be
    used as, its plain string value (for example ``"1.0"``). An enum is used
    because Typer CLI options have to be declared with an enum rather than a
    ``Literal``; call ``.value`` to obtain the underlying string.
    """

    # Colang 1.0
    one = "1.0"
    # Colang 2.0 alpha
    two_alpha = "2.0-alpha"


_COLANG_VERSIONS = [version.value for version in ColangVersions]


app = typer.Typer()

app.add_typer(cli.app, name="eval", short_help="Evaluation a guardrail configuration.")
app.add_typer(
eval_cli.app, name="eval", short_help="Evaluation a guardrail configuration."
)
app.pretty_exceptions_enable = False

logging.getLogger().setLevel(logging.WARNING)
Expand All @@ -44,7 +56,8 @@ def chat(
config: List[str] = typer.Option(
default=["config"],
exists=True,
help="Path to a directory containing configuration files to use. Can also point to a single configuration file.",
help="Path to a directory containing configuration files to use. "
"Can also point to a single configuration file.",
),
verbose: bool = typer.Option(
default=False,
Expand All @@ -60,7 +73,8 @@ def chat(
),
debug_level: List[str] = typer.Option(
default=[],
help="Enable debug mode which prints rich information about the flows execution. Available levels: WARNING, INFO, DEBUG",
help="Enable debug mode which prints rich information about the flows execution. "
"Available levels: WARNING, INFO, DEBUG",
),
streaming: bool = typer.Option(
default=False,
Expand All @@ -77,7 +91,7 @@ def chat(
):
"""Start an interactive chat session."""
if len(config) > 1:
typer.secho(f"Multiple configurations are not supported.", fg=typer.colors.RED)
typer.secho("Multiple configurations are not supported.", fg=typer.colors.RED)
typer.echo("Please provide a single folder.")
raise typer.Exit(1)

Expand Down Expand Up @@ -143,23 +157,27 @@ def server(
if config:
# We make sure there is no trailing separator, as that might break things in
# single config mode.
api.app.rails_config_path = os.path.expanduser(config[0].rstrip(os.path.sep))
setattr(
api.app,
"rails_config_path",
os.path.expanduser(config[0].rstrip(os.path.sep)),
)
else:
# If we don't have a config, we try to see if there is a local config folder
local_path = os.getcwd()
local_configs_path = os.path.join(local_path, "config")

if os.path.exists(local_configs_path):
api.app.rails_config_path = local_configs_path
setattr(api.app, "rails_config_path", local_configs_path)

if verbose:
logging.getLogger().setLevel(logging.INFO)

if disable_chat_ui:
api.app.disable_chat_ui = True
setattr(api.app, "disable_chat_ui", True)

if auto_reload:
api.app.auto_reload = True
setattr(api.app, "auto_reload", True)

if prefix:
server_app = FastAPI()
Expand All @@ -173,17 +191,14 @@ def server(
uvicorn.run(server_app, port=port, log_level="info", host="0.0.0.0")


_AVAILABLE_OPTIONS = ["1.0", "2.0-alpha"]


@app.command()
def convert(
path: str = typer.Argument(
..., help="The path to the file or directory to migrate."
),
from_version: str = typer.Option(
default="1.0",
help=f"The version of the colang files to migrate from. Available options: {_AVAILABLE_OPTIONS}.",
from_version: ColangVersions = typer.Option(
default=ColangVersions.one,
help=f"The version of the colang files to migrate from. Available options: {_COLANG_VERSIONS}.",
),
verbose: bool = typer.Option(
default=False,
Expand All @@ -209,11 +224,14 @@ def convert(

absolute_path = os.path.abspath(path)

# Typer CLI args have to use an enum, not literal. Convert to Literal here
from_version_literal: Literal["1.0", "2.0-alpha"] = from_version.value

migrate(
path=absolute_path,
include_main_flow=include_main_flow,
use_active_decorator=use_active_decorator,
from_version=from_version,
from_version=from_version_literal,
validate=validate,
)

Expand Down
105 changes: 79 additions & 26 deletions nemoguardrails/cli/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import asyncio
import json
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional, cast
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional, Tuple, Union, cast

import aiohttp
from prompt_toolkit import HTML, PromptSession
Expand All @@ -30,7 +30,11 @@
from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x
from nemoguardrails.logging import verbose
from nemoguardrails.logging.verbose import console
from nemoguardrails.streaming import StreamingHandler
from nemoguardrails.rails.llm.options import (
GenerationLog,
GenerationOptions,
GenerationResponse,
)
from nemoguardrails.utils import get_or_create_event_loop, new_event_dict, new_uuid

os.environ["TOKENIZERS_PARALLELISM"] = "false"
Expand Down Expand Up @@ -61,6 +65,8 @@ async def _run_chat_v1_0(
)

if not server_url:
if config_path is None:
raise RuntimeError("config_path cannot be None when server_url is None")
rails_config = RailsConfig.from_path(config_path)
rails_app = LLMRails(rails_config, verbose=verbose)
if streaming and not rails_config.streaming_supported:
Expand All @@ -82,7 +88,12 @@ async def _run_chat_v1_0(

if not server_url:
# If we have streaming from a locally loaded config, we initialize the handler.
if streaming and not server_url and rails_app.main_llm_supports_streaming:
if (
streaming
and not server_url
and rails_app
and rails_app.main_llm_supports_streaming
):
bot_message_list = []
async for chunk in rails_app.stream_async(messages=history):
if '{"event": "ABORT"' in chunk:
Expand All @@ -101,11 +112,40 @@ async def _run_chat_v1_0(
bot_message = {"role": "assistant", "content": bot_message_text}

else:
bot_message = await rails_app.generate_async(messages=history)
if rails_app is None:
raise RuntimeError("Rails App is None")
response: Union[
str, Dict, GenerationResponse, Tuple[Dict, Dict]
] = await rails_app.generate_async(messages=history)

# Handle different return types from generate_async
if isinstance(response, tuple) and len(response) == 2:
bot_message = (
response[0]
if response
else {"role": "assistant", "content": ""}
)
elif isinstance(response, GenerationResponse):
# GenerationResponse case
response_attr = getattr(response, "response", None)
if isinstance(response_attr, list) and len(response_attr) > 0:
bot_message = response_attr[0]
else:
bot_message = {
"role": "assistant",
"content": str(response_attr),
}
elif isinstance(response, dict):
# Direct dict case
bot_message = response
else:
# String or other fallback case
bot_message = {"role": "assistant", "content": str(response)}

if not streaming or not rails_app.main_llm_supports_streaming:
# We print bot messages in green.
console.print("[green]" + f"{bot_message['content']}" + "[/]")
content = bot_message.get("content", str(bot_message))
console.print("[green]" + f"{content}" + "[/]")
else:
data = {
"config_id": config_id,
Expand All @@ -116,19 +156,19 @@ async def _run_chat_v1_0(
async with session.post(
f"{server_url}/v1/chat/completions",
json=data,
) as response:
) as http_response:
# If the response is streaming, we show each chunk as it comes
if response.headers.get("Transfer-Encoding") == "chunked":
if http_response.headers.get("Transfer-Encoding") == "chunked":
bot_message_text = ""
async for chunk in response.content.iter_any():
chunk = chunk.decode("utf-8")
async for chunk_bytes in http_response.content.iter_any():
chunk = chunk_bytes.decode("utf-8")
console.print("[green]" + f"{chunk}" + "[/]", end="")
bot_message_text += chunk
console.print("")

bot_message = {"role": "assistant", "content": bot_message_text}
else:
result = await response.json()
result = await http_response.json()
bot_message = result["messages"][0]

# We print bot messages in green.
Expand Down Expand Up @@ -297,7 +337,8 @@ def _process_output():
else:
console.print(
"[black on magenta]"
+ f"scene information (start): (title={event['title']}, action_uid={event['action_uid']}, content={event['content']})"
+ f"scene information (start): (title={event['title']}, "
+ f"action_uid={event['action_uid']}, content={event['content']})"
+ "[/]"
)

Expand Down Expand Up @@ -333,7 +374,8 @@ def _process_output():
else:
console.print(
"[black on magenta]"
+ f"scene form (start): (prompt={event['prompt']}, action_uid={event['action_uid']}, inputs={event['inputs']})"
+ f"scene form (start): (prompt={event['prompt']}, "
+ f"action_uid={event['action_uid']}, inputs={event['inputs']})"
+ "[/]"
)
chat_state.input_events.append(
Expand Down Expand Up @@ -370,7 +412,8 @@ def _process_output():
else:
console.print(
"[black on magenta]"
+ f"scene choice (start): (prompt={event['prompt']}, action_uid={event['action_uid']}, options={event['options']})"
+ f"scene choice (start): (prompt={event['prompt']}, "
+ f"action_uid={event['action_uid']}, options={event['options']})"
+ "[/]"
)
chat_state.input_events.append(
Expand Down Expand Up @@ -452,12 +495,16 @@ async def _check_local_async_actions():
# We need to copy input events to prevent race condition
input_events_copy = chat_state.input_events.copy()
chat_state.input_events = []
(
chat_state.output_events,
chat_state.output_state,
) = await rails_app.process_events_async(
input_events_copy, chat_state.state

output_events, output_state = await rails_app.process_events_async(
input_events_copy,
asdict(chat_state.state) if chat_state.state else None,
)
chat_state.output_events = output_events

# process_events_async returns a Dict `state`, need to convert to dataclass for ChatState object
if output_state:
chat_state.output_state = cast(State, State(**output_state))

# Process output_events and potentially generate new input_events
_process_output()
Expand All @@ -470,7 +517,8 @@ async def _check_local_async_actions():
# If there are no pending actions, we stop
check_task.cancel()
check_task = None
debugger.set_output_state(chat_state.output_state)
if chat_state.output_state is not None:
debugger.set_output_state(chat_state.output_state)
chat_state.status.stop()
enable_input.set()
return
Expand All @@ -485,13 +533,16 @@ async def _process_input_events():
# We need to copy input events to prevent race condition
input_events_copy = chat_state.input_events.copy()
chat_state.input_events = []
(
chat_state.output_events,
chat_state.output_state,
) = await rails_app.process_events_async(
input_events_copy, chat_state.state
output_events, output_state = await rails_app.process_events_async(
input_events_copy,
asdict(chat_state.state) if chat_state.state else None,
)
debugger.set_output_state(chat_state.output_state)
chat_state.output_events = output_events
if output_state:
# process_events_async returns a Dict `state`, need to convert to dataclass for ChatState object
output_state_typed: State = cast(State, State(**output_state))
chat_state.output_state = output_state_typed
debugger.set_output_state(output_state_typed)

_process_output()
# If we don't have a check task, we start it
Expand Down Expand Up @@ -653,6 +704,8 @@ def run_chat(
server_url (Optional[str]): The URL of the chat server. Defaults to None.
config_id (Optional[str]): The configuration ID. Defaults to None.
"""
if config_path is None:
raise RuntimeError("config_path cannot be None")
rails_config = RailsConfig.from_path(config_path)

if verbose and verbose_llm_calls:
Expand Down
Loading
Loading