Skip to content

Commit

Permalink
Merge pull request #1202 from guardrails-ai/fix-watch
Browse files Browse the repository at this point in the history
Fix Watch Mode
  • Loading branch information
zsimjee authored Dec 24, 2024
2 parents e6ad48e + aab40b7 commit 078200e
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
12 changes: 10 additions & 2 deletions guardrails/cli/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from guardrails.cli.telemetry import trace_if_enabled
from guardrails.cli.version import version_warnings_if_applicable
from guardrails.cli.hub.console import console
from guardrails.settings import settings


def api_is_installed() -> bool:
Expand All @@ -32,15 +33,22 @@ def start(
default=8000,
help="The port to run the server on.",
),
watch: bool = typer.Option(
default=False, is_flag=True, help="Enable watch mode for logs."
),
):
logger.debug("Checking for prerequisites...")
if not api_is_installed():
package_name = 'guardrails-api>="^0.0.0a0"'
pip_process("install", package_name)

from guardrails_api.cli.start import start # type: ignore
from guardrails_api.cli.start import start as start_api # type: ignore

logger.info("Starting Guardrails server")

if watch:
settings._watch_mode_enabled = True

version_warnings_if_applicable(console)
trace_if_enabled("start")
start(env, config, port)
start_api(env, config, port)
2 changes: 2 additions & 0 deletions guardrails/cli/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import rich
import typer

from guardrails.settings import settings
from guardrails.cli.guardrails import guardrails as gr_cli
from guardrails.call_tracing import GuardTraceEntry, TraceHandler
from guardrails.cli.telemetry import trace_if_enabled
Expand All @@ -31,6 +32,7 @@ def watch_command(
default=False, is_flag=True, help="Clear all log outputs and exit."
),
):
settings._watch_mode_enabled = True
trace_if_enabled("watch")
if clear:
_clear_and_quit()
Expand Down
6 changes: 6 additions & 0 deletions guardrails/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class Settings:
_instance = None
_lock = threading.Lock()
_rc: RC
_watch_mode_enabled: bool
"""Whether to use a local server for running Guardrails."""
use_server: Optional[bool]
"""Whether to disable tracing.
Expand All @@ -29,6 +30,7 @@ def _initialize(self):
self.use_server = None
self.disable_tracing = None
self._rc = RC.load()
self._watch_mode_enabled = False

@property
def rc(self) -> RC:
Expand All @@ -40,5 +42,9 @@ def rc(self) -> RC:
def rc(self, value: RC):
self._rc = value

@property
def watch_mode_enabled(self) -> bool:
return self._watch_mode_enabled


settings = Settings()
6 changes: 4 additions & 2 deletions guardrails/telemetry/legacy_validator_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from guardrails.actions.refrain import Refrain
from guardrails.call_tracing.trace_handler import TraceHandler
from guardrails.classes.validation.validator_logs import ValidatorLogs
from guardrails.settings import settings
from guardrails.telemetry.common import get_span
from guardrails.utils.casting_utils import to_string

Expand Down Expand Up @@ -68,7 +69,8 @@ def trace_validator_result(
**kwargs,
}

TraceHandler().log_validator(validator_log)
if settings.watch_mode_enabled:
TraceHandler().log_validator(validator_log)

current_span.add_event(
f"{validator_name}_result",
Expand All @@ -85,6 +87,6 @@ def trace_validation_result(
current_span=None,
):
_current_span = get_span(current_span)
if _current_span is not None:
if _current_span is not None and not settings.disable_tracing:
for log in validation_logs:
trace_validator_result(_current_span, log, attempt_number)

0 comments on commit 078200e

Please sign in to comment.