Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions examples/tshirt_store_agent/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,8 @@
from a2a.server.apps import A2AStarletteApplication
from a2a.server.request_handlers import DefaultRequestHandler
from a2a.server.tasks import InMemoryTaskStore
from a2a.types import (
AgentCapabilities,
AgentCard,
AgentSkill,
)
from a2a.types import AgentCapabilities, AgentCard, AgentSkill
from dotenv import load_dotenv
from google.adk.artifacts import InMemoryArtifactService
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService

from .tshirt_store_agent import create_tshirt_store_agent
from .tshirt_store_agent_executor import TShirtStoreAgentExecutor
Expand All @@ -28,6 +20,12 @@
@click.option("--host", "host", default="localhost")
@click.option("--port", "port", default=10001)
def main(host: str, port: int) -> None:
# adk imports take a while, importing them here to reduce rogue startup time.
from google.adk.artifacts import InMemoryArtifactService
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService

skill = AgentSkill(
id="sell_tshirt",
name="Sell T-Shirt",
Expand Down
14 changes: 10 additions & 4 deletions examples/tshirt_store_agent/tshirt_store_agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from google.adk.agents import LlmAgent

from google.adk.agents import LlmAgent
from google.adk.models.lite_llm import LiteLlm
from google.adk.tools import FunctionTool

AGENT_INSTRUCTIONS = """
You are an agent for a t-shirt store named Shirtify.
Expand Down Expand Up @@ -72,7 +73,12 @@ def send_email_tool(
return f"Email sent to {email} with subject {subject} and body {body}"


def create_tshirt_store_agent() -> LlmAgent:
def create_tshirt_store_agent() -> "LlmAgent":
# adk imports take a while, importing them here to reduce rogue startup time.
from google.adk.agents import LlmAgent
from google.adk.models.lite_llm import LiteLlm
from google.adk.tools import FunctionTool

tools: list[FunctionTool] = [
FunctionTool(
func=inventory_tool,
Expand Down
25 changes: 14 additions & 11 deletions examples/tshirt_store_agent/tshirt_store_agent_executor.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,45 @@
import base64
from logging import getLogger
from typing import AsyncGenerator
from typing import TYPE_CHECKING, AsyncGenerator

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.tasks import TaskUpdater
from a2a.types import (
AgentCard,
TaskState,
UnsupportedOperationError,
Part,
TextPart,
FilePart,
FileWithUri,
FileWithBytes,
FileWithUri,
Part,
TaskState,
TextPart,
UnsupportedOperationError,
)
from a2a.utils.errors import ServerError
from google.adk import Runner
from google.adk.events import Event
from google.genai import types

if TYPE_CHECKING:
from google.adk import Runner
from google.adk.events import Event


logger = getLogger(__name__)


class TShirtStoreAgentExecutor(AgentExecutor):
"""An AgentExecutor that runs an ADK-based Agent."""

def __init__(self, runner: Runner, card: AgentCard):
def __init__(self, runner: "Runner", card: AgentCard):
self.runner = runner
self._card = card

self._running_sessions = {} # type: ignore

def _run_agent(
self,
session_id,
session_id: str,
new_message: types.Content,
) -> AsyncGenerator[Event, None]:
) -> AsyncGenerator["Event", None]:
return self.runner.run_async(
session_id=session_id,
user_id="self",
Expand Down
2 changes: 2 additions & 0 deletions rogue/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from . import (
agent_model_wrapper,
agent_sessions,
generic_agent_executor,
generic_task_callback,
logging,
remote_agent_connection,
tui_installer,
update_checker,
workdir_utils,
)
13 changes: 9 additions & 4 deletions rogue/common/agent_model_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from functools import lru_cache
from typing import Optional
from typing import TYPE_CHECKING, Optional

from google.adk.models import LLMRegistry, BaseLlm
from google.adk.models.lite_llm import LiteLlm
from loguru import logger

if TYPE_CHECKING:
from google.adk.models import BaseLlm


@lru_cache()
def get_llm_from_model(
model: str,
llm_auth: Optional[str] = None,
) -> BaseLlm:
) -> "BaseLlm":
# adk imports take a while, importing them here to reduce rogue startup time.
from google.adk.models import LLMRegistry
from google.adk.models.lite_llm import LiteLlm

try:
llm_cls = LLMRegistry.resolve(model)
except ValueError:
Expand Down
8 changes: 5 additions & 3 deletions rogue/common/agent_sessions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import TYPE_CHECKING
from uuid import uuid4

from google.adk.sessions import Session, BaseSessionService
if TYPE_CHECKING:
from google.adk.sessions import BaseSessionService, Session


async def create_session(
app_name: str,
session_service: BaseSessionService,
session_service: "BaseSessionService",
user_id: str | None = None,
) -> Session:
) -> "Session":
user_id = user_id or uuid4().hex
return await session_service.create_session(
app_name=app_name,
Expand Down
71 changes: 42 additions & 29 deletions rogue/common/generic_agent_executor.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,44 @@
import base64
from typing import AsyncGenerator
from typing import TYPE_CHECKING, AsyncGenerator

from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.agent_execution import AgentExecutor
from a2a.server.tasks import TaskUpdater
from a2a.types import (
AgentCard,
TaskState,
UnsupportedOperationError,
Part,
TextPart,
FilePart,
FileWithUri,
FileWithBytes,
FileWithUri,
Part,
TaskState,
TextPart,
UnsupportedOperationError,
)
from a2a.utils.errors import ServerError
from google.adk import Runner
from google.adk.events import Event
from google.genai import types
from loguru import logger

if TYPE_CHECKING:
from a2a.server.agent_execution import RequestContext
from a2a.server.events import EventQueue
from google.adk import Runner
from google.adk.events import Event
from google.genai.types import Content
from google.genai.types import Part as GenAIPart


class GenericAgentExecutor(AgentExecutor):
"""An AgentExecutor that runs an ADK-based Agent."""

def __init__(self, runner: Runner, card: AgentCard):
def __init__(self, runner: "Runner", card: AgentCard):
self.runner = runner
self._card = card

self._running_sessions = {} # type: ignore

def _run_agent(
self,
session_id,
new_message: types.Content,
) -> AsyncGenerator[Event, None]:
session_id: str,
new_message: "Content",
) -> AsyncGenerator["Event", None]:
return self.runner.run_async(
session_id=session_id,
user_id="self",
Expand All @@ -43,7 +47,7 @@ def _run_agent(

async def _process_request(
self,
new_message: types.Content,
new_message: "Content",
session_id: str,
task_updater: TaskUpdater,
) -> None:
Expand Down Expand Up @@ -79,9 +83,13 @@ async def _process_request(

async def execute(
self,
context: RequestContext,
event_queue: EventQueue,
context: "RequestContext",
event_queue: "EventQueue",
):
# google.genai imports take a while,
# importing them here to reduce rogue startup time.
from google.genai.types import UserContent

# Run the agent until either complete or the task is suspended.
updater = TaskUpdater(
event_queue,
Expand All @@ -95,15 +103,15 @@ async def execute(

if context.message is not None:
await self._process_request(
types.UserContent(
UserContent(
parts=convert_a2a_parts_to_genai(context.message.parts),
),
context.context_id or "",
updater,
)
logger.debug("EvaluatorAgentExecutor execute exiting")

async def cancel(self, context: RequestContext, event_queue: EventQueue):
async def cancel(self, context: "RequestContext", event_queue: "EventQueue"):
# Ideally: kill any ongoing tasks.
raise ServerError(error=UnsupportedOperationError())

Expand Down Expand Up @@ -136,27 +144,32 @@ async def _upsert_session(self, session_id: str):
return session


def convert_a2a_parts_to_genai(parts: list[Part]) -> list[types.Part]:
def convert_a2a_parts_to_genai(parts: list[Part]) -> list["GenAIPart"]:
"""Convert a list of A2A Part types into a list of Google Gen AI Part types."""
return [convert_a2a_part_to_genai(part) for part in parts]


def convert_a2a_part_to_genai(part: Part) -> types.Part:
def convert_a2a_part_to_genai(part: Part) -> "GenAIPart":
# google.genai imports take a while,
# importing them here to reduce rogue startup time.
from google.genai.types import Blob, FileData
from google.genai.types import Part as GenAIPart

"""Convert a single A2A Part type into a Google Gen AI Part type."""
part = part.root # type: ignore
if isinstance(part, TextPart):
return types.Part(text=part.text)
return GenAIPart(text=part.text)
if isinstance(part, FilePart):
if isinstance(part.file, FileWithUri):
return types.Part(
file_data=types.FileData(
return GenAIPart(
file_data=FileData(
file_uri=part.file.uri,
mime_type=part.file.mimeType,
),
)
if isinstance(part.file, FileWithBytes):
return types.Part(
inline_data=types.Blob(
return GenAIPart(
inline_data=Blob(
data=base64.b64decode(part.file.bytes),
mime_type=part.file.mimeType,
),
Expand All @@ -165,7 +178,7 @@ def convert_a2a_part_to_genai(part: Part) -> types.Part:
raise ValueError(f"Unsupported part type: {type(part)}")


def convert_genai_parts_to_a2a(parts: list[types.Part] | None) -> list[Part]:
def convert_genai_parts_to_a2a(parts: list["GenAIPart"] | None) -> list[Part]:
"""Convert a list of Google Gen AI Part types into a list of A2A Part types."""
parts = parts or []
return [
Expand All @@ -175,7 +188,7 @@ def convert_genai_parts_to_a2a(parts: list[types.Part] | None) -> list[Part]:
]


def convert_genai_part_to_a2a(part: types.Part) -> Part:
def convert_genai_part_to_a2a(part: "GenAIPart") -> Part:
"""Convert a single Google Gen AI Part type into an A2A Part type."""
if part.text:
return Part(root=TextPart(text=part.text))
Expand Down
3 changes: 1 addition & 2 deletions rogue/evaluator_agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from . import evaluator_agent
from . import run_evaluator_agent
from . import evaluator_agent, policy_evaluation, run_evaluator_agent
Loading
Loading