diff --git a/examples/tshirt_store_agent/__main__.py b/examples/tshirt_store_agent/__main__.py index 9e0b7715..2a8296c5 100644 --- a/examples/tshirt_store_agent/__main__.py +++ b/examples/tshirt_store_agent/__main__.py @@ -5,16 +5,8 @@ from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore -from a2a.types import ( - AgentCapabilities, - AgentCard, - AgentSkill, -) +from a2a.types import AgentCapabilities, AgentCard, AgentSkill from dotenv import load_dotenv -from google.adk.artifacts import InMemoryArtifactService -from google.adk.memory.in_memory_memory_service import InMemoryMemoryService -from google.adk.runners import Runner -from google.adk.sessions import InMemorySessionService from .tshirt_store_agent import create_tshirt_store_agent from .tshirt_store_agent_executor import TShirtStoreAgentExecutor @@ -28,6 +20,12 @@ @click.option("--host", "host", default="localhost") @click.option("--port", "port", default=10001) def main(host: str, port: int) -> None: + # adk imports take a while, importing them here to reduce rogue startup time. + from google.adk.artifacts import InMemoryArtifactService + from google.adk.memory.in_memory_memory_service import InMemoryMemoryService + from google.adk.runners import Runner + from google.adk.sessions import InMemorySessionService + skill = AgentSkill( id="sell_tshirt", name="Sell T-Shirt", diff --git a/examples/tshirt_store_agent/tshirt_store_agent.py b/examples/tshirt_store_agent/tshirt_store_agent.py index 827b9478..96e6c153 100644 --- a/examples/tshirt_store_agent/tshirt_store_agent.py +++ b/examples/tshirt_store_agent/tshirt_store_agent.py @@ -1,8 +1,9 @@ import os +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from google.adk.agents import LlmAgent -from google.adk.agents import LlmAgent -from google.adk.models.lite_llm import LiteLlm -from google.adk.tools import FunctionTool AGENT_INSTRUCTIONS = """ You are an agent for a t-shirt store named Shirtify. @@ -72,7 +73,12 @@ def send_email_tool( return f"Email sent to {email} with subject {subject} and body {body}" -def create_tshirt_store_agent() -> LlmAgent: +def create_tshirt_store_agent() -> "LlmAgent": + # adk imports take a while, importing them here to reduce rogue startup time. + from google.adk.agents import LlmAgent + from google.adk.models.lite_llm import LiteLlm + from google.adk.tools import FunctionTool + tools: list[FunctionTool] = [ FunctionTool( func=inventory_tool, diff --git a/examples/tshirt_store_agent/tshirt_store_agent_executor.py b/examples/tshirt_store_agent/tshirt_store_agent_executor.py index 32cb739a..e633adae 100644 --- a/examples/tshirt_store_agent/tshirt_store_agent_executor.py +++ b/examples/tshirt_store_agent/tshirt_store_agent_executor.py @@ -1,32 +1,35 @@ import base64 from logging import getLogger -from typing import AsyncGenerator +from typing import TYPE_CHECKING, AsyncGenerator from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import EventQueue from a2a.server.tasks import TaskUpdater from a2a.types import ( AgentCard, - TaskState, - UnsupportedOperationError, - Part, - TextPart, FilePart, - FileWithUri, FileWithBytes, + FileWithUri, + Part, + TaskState, + TextPart, + UnsupportedOperationError, ) from a2a.utils.errors import ServerError -from google.adk import Runner -from google.adk.events import Event from google.genai import types +if TYPE_CHECKING: + from google.adk import Runner + from google.adk.events import Event + + logger = getLogger(__name__) class TShirtStoreAgentExecutor(AgentExecutor): """An AgentExecutor that runs an ADK-based Agent.""" - def __init__(self, runner: Runner, card: AgentCard): + def __init__(self, runner: "Runner", card: AgentCard): self.runner = runner self._card = card @@ -34,9 +37,9 @@ def __init__(self, runner: Runner, card: AgentCard): def _run_agent( self, - session_id, + session_id: str, new_message: types.Content, - ) -> AsyncGenerator[Event, None]: + ) -> AsyncGenerator["Event", None]: return self.runner.run_async( session_id=session_id, user_id="self", diff --git a/rogue/common/__init__.py b/rogue/common/__init__.py index d26a7ee6..723c9176 100644 --- a/rogue/common/__init__.py +++ b/rogue/common/__init__.py @@ -1,9 +1,11 @@ from . import ( agent_model_wrapper, + agent_sessions, generic_agent_executor, generic_task_callback, logging, remote_agent_connection, tui_installer, + update_checker, workdir_utils, ) diff --git a/rogue/common/agent_model_wrapper.py b/rogue/common/agent_model_wrapper.py index 0b8d7955..4f69d128 100644 --- a/rogue/common/agent_model_wrapper.py +++ b/rogue/common/agent_model_wrapper.py @@ -1,16 +1,21 @@ from functools import lru_cache -from typing import Optional +from typing import TYPE_CHECKING, Optional -from google.adk.models import LLMRegistry, BaseLlm -from google.adk.models.lite_llm import LiteLlm from loguru import logger +if TYPE_CHECKING: + from google.adk.models import BaseLlm + @lru_cache() def get_llm_from_model( model: str, llm_auth: Optional[str] = None, -) -> BaseLlm: +) -> "BaseLlm": + # adk imports take a while, importing them here to reduce rogue startup time. + from google.adk.models import LLMRegistry + from google.adk.models.lite_llm import LiteLlm + try: llm_cls = LLMRegistry.resolve(model) except ValueError: diff --git a/rogue/common/agent_sessions.py b/rogue/common/agent_sessions.py index b928046c..f450890f 100644 --- a/rogue/common/agent_sessions.py +++ b/rogue/common/agent_sessions.py @@ -1,13 +1,15 @@ +from typing import TYPE_CHECKING from uuid import uuid4 -from google.adk.sessions import Session, BaseSessionService +if TYPE_CHECKING: + from google.adk.sessions import BaseSessionService, Session async def create_session( app_name: str, - session_service: BaseSessionService, + session_service: "BaseSessionService", user_id: str | None = None, -) -> Session: +) -> "Session": user_id = user_id or uuid4().hex return await session_service.create_session( app_name=app_name, diff --git a/rogue/common/generic_agent_executor.py b/rogue/common/generic_agent_executor.py index 059dc77a..9742abdc 100644 --- a/rogue/common/generic_agent_executor.py +++ b/rogue/common/generic_agent_executor.py @@ -1,30 +1,34 @@ import base64 -from typing import AsyncGenerator +from typing import TYPE_CHECKING, AsyncGenerator -from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.events import EventQueue +from a2a.server.agent_execution import AgentExecutor from a2a.server.tasks import TaskUpdater from a2a.types import ( AgentCard, - TaskState, - UnsupportedOperationError, - Part, - TextPart, FilePart, - FileWithUri, FileWithBytes, + FileWithUri, + Part, + TaskState, + TextPart, + UnsupportedOperationError, ) from a2a.utils.errors import ServerError -from google.adk import Runner -from google.adk.events import Event -from google.genai import types from loguru import logger +if TYPE_CHECKING: + from a2a.server.agent_execution import RequestContext + from a2a.server.events import EventQueue + from google.adk import Runner + from google.adk.events import Event + from google.genai.types import Content + from google.genai.types import Part as GenAIPart + class GenericAgentExecutor(AgentExecutor): """An AgentExecutor that runs an ADK-based Agent.""" - def __init__(self, runner: Runner, card: AgentCard): + def __init__(self, runner: "Runner", card: AgentCard): self.runner = runner self._card = card @@ -32,9 +36,9 @@ def __init__(self, runner: Runner, card: AgentCard): def _run_agent( self, - session_id, - new_message: types.Content, - ) -> AsyncGenerator[Event, None]: + session_id: str, + new_message: "Content", + ) -> AsyncGenerator["Event", None]: return self.runner.run_async( session_id=session_id, user_id="self", @@ -43,7 +47,7 @@ def _run_agent( async def _process_request( self, - new_message: types.Content, + new_message: "Content", session_id: str, task_updater: TaskUpdater, ) -> None: @@ -79,9 +83,13 @@ async def _process_request( async def execute( self, - context: RequestContext, - event_queue: EventQueue, + context: "RequestContext", + event_queue: "EventQueue", ): + # google.genai imports take a while, + # importing them here to reduce rogue startup time. + from google.genai.types import UserContent + # Run the agent until either complete or the task is suspended. updater = TaskUpdater( event_queue, @@ -95,7 +103,7 @@ async def execute( if context.message is not None: await self._process_request( - types.UserContent( + UserContent( parts=convert_a2a_parts_to_genai(context.message.parts), ), context.context_id or "", @@ -103,7 +111,7 @@ async def execute( ) logger.debug("EvaluatorAgentExecutor execute exiting") - async def cancel(self, context: RequestContext, event_queue: EventQueue): + async def cancel(self, context: "RequestContext", event_queue: "EventQueue"): # Ideally: kill any ongoing tasks. raise ServerError(error=UnsupportedOperationError()) @@ -136,27 +144,32 @@ async def _upsert_session(self, session_id: str): return session -def convert_a2a_parts_to_genai(parts: list[Part]) -> list[types.Part]: +def convert_a2a_parts_to_genai(parts: list[Part]) -> list["GenAIPart"]: """Convert a list of A2A Part types into a list of Google Gen AI Part types.""" return [convert_a2a_part_to_genai(part) for part in parts] -def convert_a2a_part_to_genai(part: Part) -> types.Part: +def convert_a2a_part_to_genai(part: Part) -> "GenAIPart": + # google.genai imports take a while, + # importing them here to reduce rogue startup time. + from google.genai.types import Blob, FileData + from google.genai.types import Part as GenAIPart + """Convert a single A2A Part type into a Google Gen AI Part type.""" part = part.root # type: ignore if isinstance(part, TextPart): - return types.Part(text=part.text) + return GenAIPart(text=part.text) if isinstance(part, FilePart): if isinstance(part.file, FileWithUri): - return types.Part( - file_data=types.FileData( + return GenAIPart( + file_data=FileData( file_uri=part.file.uri, mime_type=part.file.mimeType, ), ) if isinstance(part.file, FileWithBytes): - return types.Part( - inline_data=types.Blob( + return GenAIPart( + inline_data=Blob( data=base64.b64decode(part.file.bytes), mime_type=part.file.mimeType, ), @@ -165,7 +178,7 @@ def convert_a2a_part_to_genai(part: Part) -> types.Part: raise ValueError(f"Unsupported part type: {type(part)}") -def convert_genai_parts_to_a2a(parts: list[types.Part] | None) -> list[Part]: +def convert_genai_parts_to_a2a(parts: list["GenAIPart"] | None) -> list[Part]: """Convert a list of Google Gen AI Part types into a list of A2A Part types.""" parts = parts or [] return [ @@ -175,7 +188,7 @@ def convert_genai_parts_to_a2a(parts: list[types.Part] | None) -> list[Part]: ] -def convert_genai_part_to_a2a(part: types.Part) -> Part: +def convert_genai_part_to_a2a(part: "GenAIPart") -> Part: """Convert a single Google Gen AI Part type into an A2A Part type.""" if part.text: return Part(root=TextPart(text=part.text)) diff --git a/rogue/evaluator_agent/__init__.py b/rogue/evaluator_agent/__init__.py index 09289414..3a238c7e 100644 --- a/rogue/evaluator_agent/__init__.py +++ b/rogue/evaluator_agent/__init__.py @@ -1,2 +1 @@ -from . import evaluator_agent -from . import run_evaluator_agent +from . import evaluator_agent, policy_evaluation, run_evaluator_agent diff --git a/rogue/evaluator_agent/evaluator_agent.py b/rogue/evaluator_agent/evaluator_agent.py index 7fade583..3738b88d 100644 --- a/rogue/evaluator_agent/evaluator_agent.py +++ b/rogue/evaluator_agent/evaluator_agent.py @@ -1,13 +1,9 @@ import json -from typing import Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional from uuid import uuid4 from a2a.client import A2ACardResolver from a2a.types import Message, MessageSendParams, Part, Role, Task, TextPart -from google.adk.agents import LlmAgent -from google.adk.agents.callback_context import CallbackContext -from google.adk.models import LlmRequest, LlmResponse -from google.adk.tools import BaseTool, FunctionTool, ToolContext from google.genai import types from httpx import AsyncClient from loguru import logger @@ -31,6 +27,13 @@ ) from ..evaluator_agent.policy_evaluation import evaluate_policy +if TYPE_CHECKING: + from google.adk.agents import LlmAgent + from google.adk.agents.callback_context import CallbackContext + from google.adk.models import LlmRequest, LlmResponse + from google.adk.tools import BaseTool, ToolContext + + FAST_MODE_AGENT_INSTRUCTIONS = """ You are a scenario tester agent. Your task is to test the given scenarios against another agent and evaluate whether that agent passes or fails each test scenario. @@ -196,7 +199,11 @@ async def _get_evaluated_agent_client(self) -> RemoteAgentConnections: return self.__evaluated_agent_client - def get_underlying_agent(self) -> LlmAgent: + def get_underlying_agent(self) -> "LlmAgent": + # adk imports take a while, importing them here to reduce rogue startup time. + from google.adk.agents import LlmAgent + from google.adk.tools import FunctionTool + instructions_template = ( AGENT_INSTRUCTIONS if self._deep_test_mode else FAST_MODE_AGENT_INSTRUCTIONS ) @@ -251,9 +258,9 @@ def get_underlying_agent(self) -> LlmAgent: def _before_tool_callback( self, - tool: BaseTool, + tool: "BaseTool", args: dict[str, Any], - tool_context: ToolContext, + tool_context: "ToolContext", ) -> Optional[dict]: # Always log tool calls, not just in debug mode logger.info( @@ -271,9 +278,9 @@ def _before_tool_callback( def _after_tool_callback( self, - tool: BaseTool, + tool: "BaseTool", args: dict[str, Any], - tool_context: ToolContext, + tool_context: "ToolContext", tool_response: Optional[dict], ) -> Optional[dict]: # Always log tool responses, not just in debug mode @@ -293,8 +300,8 @@ def _after_tool_callback( def _before_model_callback( self, - callback_context: CallbackContext, - llm_request: LlmRequest, + callback_context: "CallbackContext", + llm_request: "LlmRequest", ) -> None: # Always log LLM requests to see what the judge is being asked logger.info( @@ -308,8 +315,8 @@ def _before_model_callback( def _after_model_callback( self, - callback_context: CallbackContext, - llm_response: LlmResponse, + callback_context: "CallbackContext", + llm_response: "LlmResponse", ) -> None: if not self._debug: return None diff --git a/rogue/evaluator_agent/policy_evaluation.py b/rogue/evaluator_agent/policy_evaluation.py index de51eb2f..15cbc5f4 100644 --- a/rogue/evaluator_agent/policy_evaluation.py +++ b/rogue/evaluator_agent/policy_evaluation.py @@ -1,7 +1,6 @@ import os import re -from litellm import completion from loguru import logger from pydantic import ValidationError from rogue_sdk.types import ChatHistory @@ -142,6 +141,9 @@ def evaluate_policy( expected_outcome: str | None = None, api_key: str | None = None, ) -> PolicyEvaluationResult: + # litellm import takes a while, importing here to reduce startup time. + from litellm import completion + if "/" not in model and model.startswith("gemini"): if os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "false").lower() == "true": model = f"vertex_ai/{model}" diff --git a/rogue/evaluator_agent/run_evaluator_agent.py b/rogue/evaluator_agent/run_evaluator_agent.py index e3ee6194..257864ee 100644 --- a/rogue/evaluator_agent/run_evaluator_agent.py +++ b/rogue/evaluator_agent/run_evaluator_agent.py @@ -1,10 +1,7 @@ import asyncio from asyncio import Queue -from typing import Any, AsyncGenerator +from typing import TYPE_CHECKING, Any, AsyncGenerator -from a2a.types import AgentCapabilities, AgentCard, AgentSkill -from google.adk.runners import Runner -from google.adk.sessions import InMemorySessionService, Session from google.genai import types from httpx import AsyncClient from loguru import logger @@ -13,32 +10,15 @@ from ..common.agent_sessions import create_session from .evaluator_agent import EvaluatorAgent - -def _get_agent_card(host: str, port: int): - skill = AgentSkill( - id="evaluate_agent", - name="Evaluate Agent", - description="Evaluate an agent and provide a report", - tags=["evaluate"], - examples=["evaluate the agent hosted at http://localhost:10001"], - ) - - return AgentCard( - name="Qualifire Agent Evaluator", - description="Evaluates an agent is working as intended and provides a report", - url=f"http://{host}:{port}/", - version="1.0.0", - defaultInputModes=["text"], - defaultOutputModes=["text"], - capabilities=AgentCapabilities(streaming=True), - skills=[skill], - ) +if TYPE_CHECKING: + from google.adk.runners import Runner + from google.adk.sessions import Session async def _run_agent( - agent_runner: Runner, + agent_runner: "Runner", input_text: str, - session: Session, + session: "Session", ) -> str: input_text_preview = ( input_text[:100] + "..." if len(input_text) > 100 else input_text @@ -102,6 +82,10 @@ async def arun_evaluator_agent( business_context: str, deep_test_mode: bool, ) -> AsyncGenerator[tuple[str, Any], None]: + # adk imports take a while, importing them here to reduce rogue startup time. + from google.adk.runners import Runner + from google.adk.sessions import InMemorySessionService + logger.info( "🤖 arun_evaluator_agent starting", extra={ diff --git a/rogue/prompt_injection_evaluator/run_prompt_injection_evaluator.py b/rogue/prompt_injection_evaluator/run_prompt_injection_evaluator.py index a8f25b87..29d53900 100644 --- a/rogue/prompt_injection_evaluator/run_prompt_injection_evaluator.py +++ b/rogue/prompt_injection_evaluator/run_prompt_injection_evaluator.py @@ -2,11 +2,9 @@ from typing import Any, AsyncGenerator, Optional from uuid import uuid4 -import datasets import httpx from a2a.client import A2ACardResolver from a2a.types import Message, MessageSendParams, Part, Role, Task, TextPart -from litellm import completion from loguru import logger from rogue_sdk.types import AuthType, ChatHistory, ChatMessage @@ -79,6 +77,9 @@ async def _judge_injection_attempt( judge_llm: str, judge_llm_api_key: Optional[str], ) -> PromptInjectionEvaluation: + # litellm import takes a while, importing here to reduce startup time. + from litellm import completion + prompt = EVALUATION_PROMPT_TEMPLATE.format( conversation_history=chat_history.model_dump_json(indent=2), payload=payload.payload, @@ -121,8 +122,11 @@ async def arun_prompt_injection_evaluator( dataset_name: str, sample_size: int | None, ) -> AsyncGenerator[tuple[str, Any], None]: + # datasets import takes a while, importing here to reduce startup time. + from datasets import load_dataset + headers = auth_type.get_auth_header(auth_credentials) - dataset_dict = datasets.load_dataset(dataset_name) + dataset_dict = load_dataset(dataset_name) # Pick a split to use. Prioritize 'train', then take the first available. if "train" in dataset_dict: diff --git a/rogue/server/__init__.py b/rogue/server/__init__.py index b1bde018..768c8f2f 100644 --- a/rogue/server/__init__.py +++ b/rogue/server/__init__.py @@ -5,4 +5,4 @@ Provides REST API endpoints and WebSocket support for agent evaluation. """ -from . import api, core, services, websocket +from . import api, core, models, services, websocket diff --git a/rogue/server/api/__init__.py b/rogue/server/api/__init__.py index 0b1b8c2e..e1dec064 100644 --- a/rogue/server/api/__init__.py +++ b/rogue/server/api/__init__.py @@ -2,11 +2,6 @@ API endpoints for the Rogue Agent Evaluator Server. """ -from . import ( - evaluation, - health, - interview, - llm, -) +from . import evaluation, health, interview, llm __all__ = ["evaluation", "health", "interview", "llm"] diff --git a/rogue/server/services/__init__.py b/rogue/server/services/__init__.py index 6b0b3aaf..d9e2be20 100644 --- a/rogue/server/services/__init__.py +++ b/rogue/server/services/__init__.py @@ -4,6 +4,6 @@ evaluation_service, interviewer_service, llm_service, - scenario_evaluation_service, qualifire_service, + scenario_evaluation_service, ) diff --git a/rogue/server/services/interviewer_service.py b/rogue/server/services/interviewer_service.py index d9d9d333..9930ca43 100644 --- a/rogue/server/services/interviewer_service.py +++ b/rogue/server/services/interviewer_service.py @@ -1,7 +1,5 @@ from typing import Any, Dict, Iterator -from litellm import completion - INTERVIEWER_SYSTEM_PROMPT = """ You are an AI interviewer tasked with extracting a business context from a user about their AI agent. Your goal is to gather enough information to later generate test scenarios, @@ -67,6 +65,9 @@ def __init__( ] def send_message(self, user_input: str): + # litellm import takes a while, importing here to reduce startup time. + from litellm import completion + self._messages.append( { "role": "user", diff --git a/rogue/server/services/llm_service.py b/rogue/server/services/llm_service.py index a775edf9..ce49a8fc 100644 --- a/rogue/server/services/llm_service.py +++ b/rogue/server/services/llm_service.py @@ -1,11 +1,14 @@ import json from typing import Optional -from litellm import completion from loguru import logger -from rogue_sdk.types import EvaluationResults, Scenario, Scenarios, ScenarioType -from rogue_sdk.types import StructuredSummary - +from rogue_sdk.types import ( + EvaluationResults, + Scenario, + Scenarios, + ScenarioType, + StructuredSummary, +) SCENARIO_GENERATION_SYSTEM_PROMPT = """ # Test Scenario Designer @@ -170,6 +173,9 @@ def generate_scenarios( context: str, llm_provider_api_key: Optional[str] = None, ) -> Scenarios: + # litellm import takes a while, importing here to reduce startup time. + from litellm import completion + """Generate test scenarios from business context using LLM. Args: @@ -222,6 +228,9 @@ def generate_summary_from_results( results: EvaluationResults, llm_provider_api_key: Optional[str] = None, ) -> StructuredSummary: + # litellm import takes a while, importing here to reduce startup time. + from litellm import completion + system_prompt = SUMMARY_GENERATION_SYSTEM_PROMPT.replace( r"{$EVALUATION_RESULTS}", results.model_dump_json(indent=2), diff --git a/rogue/ui/__init__.py b/rogue/ui/__init__.py index d5aac72e..58431333 100644 --- a/rogue/ui/__init__.py +++ b/rogue/ui/__init__.py @@ -1,4 +1 @@ -from . import app -from . import components -from . import config -from . import models +from . import app, components, config, models diff --git a/rogue/ui/app.py b/rogue/ui/app.py index 2d05689d..f4014234 100644 --- a/rogue/ui/app.py +++ b/rogue/ui/app.py @@ -1,7 +1,6 @@ import json from pathlib import Path -import gradio as gr from rogue_sdk.types import AuthType from ..common.workdir_utils import load_config @@ -13,11 +12,14 @@ ) from .components.scenario_generator import create_scenario_generator_screen from .components.scenario_runner import create_scenario_runner_screen -from .config.theme import theme +from .config.theme import get_theme def get_app(workdir: Path, rogue_server_url: str): - with gr.Blocks(theme=theme, title="Qualifire Agent Evaluator") as app: + # gradio import takes a while, importing here to reduce startup time. + import gradio as gr + + with gr.Blocks(theme=get_theme(), title="Qualifire Agent Evaluator") as app: shared_state = gr.State( { "config": {}, diff --git a/rogue/ui/components/config_screen.py b/rogue/ui/components/config_screen.py index e32e1cb5..e7ff48ca 100644 --- a/rogue/ui/components/config_screen.py +++ b/rogue/ui/components/config_screen.py @@ -1,15 +1,22 @@ -import gradio as gr +from typing import TYPE_CHECKING + from loguru import logger from pydantic import ValidationError from rogue_sdk.types import AgentConfig, AuthType from ...common.workdir_utils import dump_config +if TYPE_CHECKING: + from gradio import State, Tabs + def create_config_screen( - shared_state: gr.State, - tabs_component: gr.Tabs, + shared_state: "State", + tabs_component: "Tabs", ): + # gradio import takes a while, importing here to reduce startup time. + import gradio as gr + config_data = {} if shared_state.value and isinstance(shared_state.value, dict): config_data = shared_state.value.get("config", {}) diff --git a/rogue/ui/components/interviewer.py b/rogue/ui/components/interviewer.py index dac63223..174f9b3a 100644 --- a/rogue/ui/components/interviewer.py +++ b/rogue/ui/components/interviewer.py @@ -1,17 +1,22 @@ import asyncio -from typing import List +from typing import TYPE_CHECKING, List -import gradio as gr from loguru import logger from rogue_sdk import RogueClientConfig, RogueSDK from ...common.workdir_utils import dump_business_context +if TYPE_CHECKING: + from gradio import State, Tabs + def create_interviewer_screen( - shared_state: gr.State, - tabs_component: gr.Tabs, + shared_state: "State", + tabs_component: "Tabs", ): + # gradio import takes a while, importing here to reduce startup time. + import gradio as gr + with gr.Column(): gr.Markdown("## AI-Powered Interviewer") gr.Markdown( diff --git a/rogue/ui/components/report_generator.py b/rogue/ui/components/report_generator.py index db554368..3f502372 100644 --- a/rogue/ui/components/report_generator.py +++ b/rogue/ui/components/report_generator.py @@ -1,12 +1,14 @@ from pathlib import Path -from typing import Tuple +from typing import TYPE_CHECKING, Tuple -import gradio as gr from loguru import logger from rogue_sdk.types import EvaluationResults from ...server.services.api_format_service import convert_with_structured_summary +if TYPE_CHECKING: + from gradio import JSON, Markdown, State + def _load_report_data_from_files( evaluation_results_output_path: Path | None, @@ -29,8 +31,11 @@ def _load_report_data_from_files( def create_report_generator_screen( - shared_state: gr.State, -) -> Tuple[gr.JSON, gr.Markdown]: + shared_state: "State", +) -> Tuple["JSON", "Markdown"]: + # gradio import takes a while, importing here to reduce startup time. + import gradio as gr + with gr.Column(): gr.Markdown("## Summary") summary_display = gr.Markdown( @@ -48,6 +53,9 @@ def setup_report_generator_logic( summary_display, shared_state, ): + # gradio import takes a while, importing here to reduce startup time. + import gradio as gr + def on_report_tab_select(state): results = state.get("results", EvaluationResults()) summary = state.get("summary", "No summary available.") diff --git a/rogue/ui/components/scenario_generator.py b/rogue/ui/components/scenario_generator.py index e6c3ca58..2191080b 100644 --- a/rogue/ui/components/scenario_generator.py +++ b/rogue/ui/components/scenario_generator.py @@ -1,13 +1,19 @@ import asyncio +from typing import TYPE_CHECKING -import gradio as gr from loguru import logger from rogue_sdk import RogueClientConfig, RogueSDK from ...common.workdir_utils import dump_business_context, dump_scenarios +if TYPE_CHECKING: + from gradio import State, Tabs + + +def create_scenario_generator_screen(shared_state: "State", tabs_component: "Tabs"): + # gradio import takes a while, importing here to reduce startup time. + import gradio as gr -def create_scenario_generator_screen(shared_state: gr.State, tabs_component: gr.Tabs): with gr.Column(): gr.Markdown("## Scenario Generation") business_context_display = gr.Textbox( diff --git a/rogue/ui/components/scenario_runner.py b/rogue/ui/components/scenario_runner.py index e46c5540..0d8d5591 100644 --- a/rogue/ui/components/scenario_runner.py +++ b/rogue/ui/components/scenario_runner.py @@ -1,7 +1,7 @@ import asyncio import json +from typing import TYPE_CHECKING -import gradio as gr from loguru import logger from pydantic import HttpUrl from rogue_sdk import RogueClientConfig, RogueSDK @@ -16,6 +16,9 @@ from ...common.workdir_utils import dump_scenarios +if TYPE_CHECKING: + from gradio import State, Tabs + MAX_PARALLEL_RUNS = 10 @@ -41,7 +44,10 @@ def split_into_batches(scenarios: list, n: int) -> list[list]: return batches -def create_scenario_runner_screen(shared_state: gr.State, tabs_component: gr.Tabs): +def create_scenario_runner_screen(shared_state: "State", tabs_component: "Tabs"): + # gradio import takes a while, importing here to reduce startup time. + import gradio as gr + with gr.Column(): gr.Markdown("## Scenario Runner & Evaluator") with gr.Accordion("scenarios to Run"): diff --git a/rogue/ui/config/theme.py b/rogue/ui/config/theme.py index 42ba635c..230a4c53 100644 --- a/rogue/ui/config/theme.py +++ b/rogue/ui/config/theme.py @@ -1,48 +1,56 @@ -import gradio as gr +from typing import TYPE_CHECKING -theme = gr.themes.Soft( - primary_hue=gr.themes.Color( - c50="#ECE9FB", - c100="#ECE9FB", - c200="#ECE9FB", - c300="#6B63BF", - c400="#494199", - c500="#A5183A", - c600="#332E68", - c700="#272350", - c800="#201E44", - c900="#1C1A3D", - c950="#100F24", - ), - secondary_hue=gr.themes.Color( - c50="#ECE9FB", - c100="#ECE9FB", - c200="#ECE9FB", - c300="#6B63BF", - c400="#494199", - c500="#494199", - c600="#332E68", - c700="#272350", - c800="#201E44", - c900="#1C1A3D", - c950="#100F24", - ), - neutral_hue=gr.themes.Color( - c50="#ECE9FB", - c100="#ECE9FB", - c200="#ECE9FB", - c300="#6B63BF", - c400="#494199", - c500="#494199", - c600="#332E68", - c700="#272350", - c800="#201E44", - c900="#1C1A3D", - c950="#100F24", - ), - font=[ - gr.themes.GoogleFont("Mulish"), - "Arial", - "sans-serif", - ], -) +if TYPE_CHECKING: + from gradio.themes import ThemeClass + + +def get_theme() -> "ThemeClass": + # gradio import takes a while, importing here to reduce startup time. + from gradio.themes import Color, GoogleFont, Soft + + return Soft( + primary_hue=Color( + c50="#ECE9FB", + c100="#ECE9FB", + c200="#ECE9FB", + c300="#6B63BF", + c400="#494199", + c500="#A5183A", + c600="#332E68", + c700="#272350", + c800="#201E44", + c900="#1C1A3D", + c950="#100F24", + ), + secondary_hue=Color( + c50="#ECE9FB", + c100="#ECE9FB", + c200="#ECE9FB", + c300="#6B63BF", + c400="#494199", + c500="#494199", + c600="#332E68", + c700="#272350", + c800="#201E44", + c900="#1C1A3D", + c950="#100F24", + ), + neutral_hue=Color( + c50="#ECE9FB", + c100="#ECE9FB", + c200="#ECE9FB", + c300="#6B63BF", + c400="#494199", + c500="#494199", + c600="#332E68", + c700="#272350", + c800="#201E44", + c900="#1C1A3D", + c950="#100F24", + ), + font=[ + GoogleFont("Mulish"), + "Arial", + "sans-serif", + ], + )