Skip to content

Commit a12ae81

Browse files
GWealecopybara-github
authored andcommitted
feat: Add service factory for configurable session and artifact backends
this creates service_factory to handle .adk folder changes (including per-agent .adk defaults and in-memory/custom URI handling) Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 833875524
1 parent 8eb1bdb commit a12ae81

File tree

1 file changed

+36
-2
lines changed

1 file changed

+36
-2
lines changed

src/google/adk/cli/service_registry.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,12 @@ def my_session_factory(uri: str, **kwargs):
6666
import importlib
6767
import logging
6868
import os
69+
from pathlib import Path
6970
import sys
7071
from typing import Any
7172
from typing import Optional
7273
from typing import Protocol
74+
from urllib.parse import unquote
7375
from urllib.parse import urlparse
7476

7577
from ..artifacts.base_artifact_service import BaseArtifactService
@@ -218,6 +220,11 @@ def _register_builtin_services(registry: ServiceRegistry) -> None:
218220
"""Register built-in service implementations."""
219221

220222
# -- Session Services --
223+
def memory_session_factory(uri: str, **kwargs):
224+
from ..sessions.in_memory_session_service import InMemorySessionService
225+
226+
return InMemorySessionService()
227+
221228
def agentengine_session_factory(uri: str, **kwargs):
222229
from ..sessions.vertex_ai_session_service import VertexAiSessionService
223230

@@ -240,19 +247,26 @@ def sqlite_session_factory(uri: str, **kwargs):
240247
parsed = urlparse(uri)
241248
db_path = parsed.path
242249
if not db_path:
243-
return InMemorySessionService()
250+
# Treat sqlite:// without a path as an in-memory session service.
251+
return memory_session_factory("memory://", **kwargs)
244252
elif db_path.startswith("/"):
245253
db_path = db_path[1:]
246254
kwargs_copy = kwargs.copy()
247255
kwargs_copy.pop("agents_dir", None)
248256
return SqliteSessionService(db_path=db_path, **kwargs_copy)
249257

258+
registry.register_session_service("memory", memory_session_factory)
250259
registry.register_session_service("agentengine", agentengine_session_factory)
251260
registry.register_session_service("sqlite", sqlite_session_factory)
252261
for scheme in ["postgresql", "mysql"]:
253262
registry.register_session_service(scheme, database_session_factory)
254263

255264
# -- Artifact Services --
265+
def memory_artifact_factory(uri: str, **kwargs):
266+
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
267+
268+
return InMemoryArtifactService()
269+
256270
def gcs_artifact_factory(uri: str, **kwargs):
257271
from ..artifacts.gcs_artifact_service import GcsArtifactService
258272

@@ -262,15 +276,35 @@ def gcs_artifact_factory(uri: str, **kwargs):
262276
bucket_name = parsed_uri.netloc
263277
return GcsArtifactService(bucket_name=bucket_name, **kwargs_copy)
264278

279+
def file_artifact_factory(uri: str, **kwargs):
280+
from ..artifacts.file_artifact_service import FileArtifactService
281+
282+
per_agent = kwargs.get("per_agent", False)
283+
if per_agent:
284+
raise ValueError(
285+
"file:// artifact URIs are not supported in multi-agent mode."
286+
)
287+
parsed_uri = urlparse(uri)
288+
if parsed_uri.netloc not in ("", "localhost"):
289+
raise ValueError(
290+
"file:// artifact URIs must reference the local filesystem."
291+
)
292+
if not parsed_uri.path:
293+
raise ValueError("file:// artifact URIs must include a path component.")
294+
artifact_path = Path(unquote(parsed_uri.path))
295+
return FileArtifactService(root_dir=artifact_path)
296+
297+
registry.register_artifact_service("memory", memory_artifact_factory)
265298
registry.register_artifact_service("gs", gcs_artifact_factory)
299+
registry.register_artifact_service("file", file_artifact_factory)
266300

267301
# -- Memory Services --
268302
def rag_memory_factory(uri: str, **kwargs):
269303
from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService
270304

271305
rag_corpus = urlparse(uri).netloc
272306
if not rag_corpus:
273-
raise ValueError("Rag corpus cannot be empty.")
307+
raise ValueError("Rag corpus can not be empty.")
274308
agents_dir = kwargs.get("agents_dir")
275309
project, location = _load_gcp_config(agents_dir, "RAG memory service")
276310
return VertexAiRagMemoryService(

0 commit comments

Comments
 (0)