Skip to content

Commit e212ff5

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Add new methods in the artifact service interface
PiperOrigin-RevId: 818473733
1 parent e63180c commit e212ff5

File tree

4 files changed

+163
-1
lines changed

4 files changed

+163
-1
lines changed

src/google/adk/artifacts/base_artifact_service.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,23 @@
1515

1616
from abc import ABC
1717
from abc import abstractmethod
18+
from typing import Any
1819
from typing import Optional
1920

2021
from google.genai import types
22+
from pydantic import BaseModel
23+
from pydantic import Field
24+
25+
26+
class ArtifactVersion(BaseModel):
27+
"""Represents the metadata of a specific version of an artifact."""
28+
29+
version: int
30+
"""The version number of the artifact."""
31+
canonical_uri: str
32+
"""The canonical URI of the artifact version."""
33+
custom_metadata: dict[str, Any] = Field(default_factory=dict)
34+
"""A dictionary of custom metadata associated with the artifact version."""
2135

2236

2337
class BaseArtifactService(ABC):
@@ -32,6 +46,7 @@ async def save_artifact(
3246
filename: str,
3347
artifact: types.Part,
3448
session_id: Optional[str] = None,
49+
custom_metadata: Optional[dict[str, Any]] = None,
3550
) -> int:
3651
"""Saves an artifact to the artifact service storage.
3752
@@ -43,8 +58,12 @@ async def save_artifact(
4358
app_name: The app name.
4459
user_id: The user ID.
4560
filename: The filename of the artifact.
46-
artifact: The artifact to save.
61+
artifact: The artifact to save. If the artifact consists of `file_data`,
62+
the artifact service assumes its content has been uploaded separately,
63+
and this method will associate the `file_data` with the artifact if
64+
necessary.
4765
session_id: The session ID. If `None`, the artifact is user-scoped.
66+
custom_metadata: custom metadata to associate with the artifact.
4867
4968
Returns:
5069
The revision ID. The first version of the artifact has a revision ID of 0.
@@ -136,3 +155,54 @@ async def list_versions(
136155
Returns:
137156
A list of all available versions of the artifact.
138157
"""
158+
159+
@abstractmethod
160+
async def list_artifact_versions(
161+
self,
162+
*,
163+
app_name: str,
164+
user_id: str,
165+
filename: str,
166+
session_id: Optional[str] = None,
167+
) -> list[ArtifactVersion]:
168+
"""Lists all versions and their metadata for a specific artifact.
169+
170+
Args:
171+
app_name: The name of the application.
172+
user_id: The ID of the user.
173+
filename: The name of the artifact file.
174+
session_id: The ID of the session. If `None`, lists versions of the
175+
user-scoped artifact. Otherwise, lists versions of the artifact within
176+
the specified session.
177+
178+
Returns:
179+
A list of ArtifactVersion objects, each representing a version of the
180+
artifact and its associated metadata.
181+
"""
182+
183+
@abstractmethod
184+
async def get_artifact_version(
185+
self,
186+
*,
187+
app_name: str,
188+
user_id: str,
189+
filename: str,
190+
session_id: Optional[str] = None,
191+
version: Optional[int] = None,
192+
) -> Optional[ArtifactVersion]:
193+
"""Gets the metadata for a specific version of an artifact.
194+
195+
Args:
196+
app_name: The name of the application.
197+
user_id: The ID of the user.
198+
filename: The name of the artifact file.
199+
session_id: The ID of the session. If `None`, the artifact will be fetched
200+
from the user-scoped artifacts. Otherwise, it will be fetched from the
201+
specified session.
202+
version: The version number of the artifact to retrieve. If `None`, the
203+
latest version will be returned.
204+
205+
Returns:
206+
An ArtifactVersion object containing the metadata of the specified
207+
artifact version, or `None` if the artifact version is not found.
208+
"""

src/google/adk/artifacts/gcs_artifact_service.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@
2424

2525
import asyncio
2626
import logging
27+
from typing import Any
2728
from typing import Optional
2829

2930
from google.cloud import storage
3031
from google.genai import types
3132
from typing_extensions import override
3233

34+
from .base_artifact_service import ArtifactVersion
3335
from .base_artifact_service import BaseArtifactService
3436

3537
logger = logging.getLogger("google_adk." + __name__)
@@ -58,6 +60,7 @@ async def save_artifact(
5860
filename: str,
5961
artifact: types.Part,
6062
session_id: Optional[str] = None,
63+
custom_metadata: Optional[dict[str, Any]] = None,
6164
) -> int:
6265
return await asyncio.to_thread(
6366
self._save_artifact,
@@ -66,6 +69,7 @@ async def save_artifact(
6669
session_id,
6770
filename,
6871
artifact,
72+
custom_metadata,
6973
)
7074

7175
@override
@@ -180,7 +184,12 @@ def _save_artifact(
180184
session_id: Optional[str],
181185
filename: str,
182186
artifact: types.Part,
187+
custom_metadata: Optional[dict[str, Any]] = None,
183188
) -> int:
189+
if custom_metadata:
190+
# TODO: b/447451270 - support saving artifact with custom metadata.
191+
raise NotImplementedError("custom_metadata is not supported yet.")
192+
184193
versions = self._list_versions(
185194
app_name=app_name,
186195
user_id=user_id,
@@ -316,3 +325,28 @@ def _list_versions(
316325
*_, version = blob.name.split("/")
317326
versions.append(int(version))
318327
return versions
328+
329+
@override
330+
async def list_artifact_versions(
331+
self,
332+
*,
333+
app_name: str,
334+
user_id: str,
335+
filename: str,
336+
session_id: Optional[str] = None,
337+
) -> list[ArtifactVersion]:
338+
# TODO: b/447451270 - Support list_artifact_versions.
339+
raise NotImplementedError("list_artifact_versions is not implemented yet.")
340+
341+
@override
342+
async def get_artifact_version(
343+
self,
344+
*,
345+
app_name: str,
346+
user_id: str,
347+
filename: str,
348+
session_id: Optional[str] = None,
349+
version: Optional[int] = None,
350+
) -> Optional[ArtifactVersion]:
351+
# TODO: b/447451270 - Support get_artifact_version.
352+
raise NotImplementedError("get_artifact_version is not implemented yet.")

src/google/adk/artifacts/in_memory_artifact_service.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
from __future__ import annotations
1515

1616
import logging
17+
from typing import Any
1718
from typing import Optional
1819

1920
from google.genai import types
2021
from pydantic import BaseModel
2122
from pydantic import Field
2223
from typing_extensions import override
2324

25+
from .base_artifact_service import ArtifactVersion
2426
from .base_artifact_service import BaseArtifactService
2527

2628
logger = logging.getLogger("google_adk." + __name__)
@@ -83,7 +85,12 @@ async def save_artifact(
8385
filename: str,
8486
artifact: types.Part,
8587
session_id: Optional[str] = None,
88+
custom_metadata: Optional[dict[str, Any]] = None,
8689
) -> int:
90+
# TODO: b/447451270 - Support saving artifact with custom metadata.
91+
if custom_metadata:
92+
raise NotImplementedError("custom_metadata is not supported yet.")
93+
8794
path = self._artifact_path(app_name, user_id, filename, session_id)
8895
if path not in self.artifacts:
8996
self.artifacts[path] = []
@@ -155,3 +162,28 @@ async def list_versions(
155162
if not versions:
156163
return []
157164
return list(range(len(versions)))
165+
166+
@override
167+
async def list_artifact_versions(
168+
self,
169+
*,
170+
app_name: str,
171+
user_id: str,
172+
filename: str,
173+
session_id: Optional[str] = None,
174+
) -> list[ArtifactVersion]:
175+
# TODO: b/447451270 - Support list_artifact_versions.
176+
raise NotImplementedError("list_artifact_versions is not implemented yet.")
177+
178+
@override
179+
async def get_artifact_version(
180+
self,
181+
*,
182+
app_name: str,
183+
user_id: str,
184+
filename: str,
185+
session_id: Optional[str] = None,
186+
version: Optional[int] = None,
187+
) -> Optional[ArtifactVersion]:
188+
# TODO: b/447451270 - Support get_artifact_version.
189+
raise NotImplementedError("get_artifact_version is not implemented yet.")

src/google/adk/tools/_forwarding_artifact_service.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Any
1718
from typing import Optional
1819
from typing import TYPE_CHECKING
1920

2021
from google.genai import types
2122
from typing_extensions import override
2223

24+
from ..artifacts.base_artifact_service import ArtifactVersion
2325
from ..artifacts.base_artifact_service import BaseArtifactService
2426

2527
if TYPE_CHECKING:
@@ -42,6 +44,7 @@ async def save_artifact(
4244
filename: str,
4345
artifact: types.Part,
4446
session_id: Optional[str] = None,
47+
custom_metadata: Optional[dict[str, Any]] = None,
4548
) -> int:
4649
return await self.tool_context.save_artifact(
4750
filename=filename, artifact=artifact
@@ -104,3 +107,26 @@ async def list_versions(
104107
session_id=self._invocation_context.session.id,
105108
filename=filename,
106109
)
110+
111+
@override
112+
async def list_artifact_versions(
113+
self,
114+
*,
115+
app_name: str,
116+
user_id: str,
117+
filename: str,
118+
session_id: Optional[str] = None,
119+
) -> list[ArtifactVersion]:
120+
raise NotImplementedError("list_artifact_versions is not implemented yet.")
121+
122+
@override
123+
async def get_artifact_version(
124+
self,
125+
*,
126+
app_name: str,
127+
user_id: str,
128+
filename: str,
129+
session_id: Optional[str] = None,
130+
version: Optional[int] = None,
131+
) -> Optional[ArtifactVersion]:
132+
raise NotImplementedError("get_artifact_version is not implemented yet.")

0 commit comments

Comments
 (0)