Skip to content

Commit 5a543c0

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: implement new methods in in-memory artifact service
* save_artifact with custom_metadata * list_artifact_versions * get_artifact_version PiperOrigin-RevId: 820321444
1 parent 1e1d63f commit 5a543c0

File tree

3 files changed

+259
-18
lines changed

3 files changed

+259
-18
lines changed

src/google/adk/artifacts/in_memory_artifact_service.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
import dataclasses
1617
import logging
1718
from typing import Any
1819
from typing import Optional
@@ -28,14 +29,27 @@
2829
logger = logging.getLogger("google_adk." + __name__)
2930

3031

32+
@dataclasses.dataclass
33+
class _ArtifactEntry:
34+
"""Represents a single version of an artifact stored in memory.
35+
36+
Attributes:
37+
data: The actual data of the artifact.
38+
artifact_version: Metadata about this specific version of the artifact.
39+
"""
40+
41+
data: types.Part
42+
artifact_version: ArtifactVersion
43+
44+
3145
class InMemoryArtifactService(BaseArtifactService, BaseModel):
3246
"""An in-memory implementation of the artifact service.
3347
3448
It is not suitable for multi-threaded production environments. Use it for
3549
testing and development only.
3650
"""
3751

38-
artifacts: dict[str, list[types.Part]] = Field(default_factory=dict)
52+
artifacts: dict[str, list[_ArtifactEntry]] = Field(default_factory=dict)
3953

4054
def _file_has_user_namespace(self, filename: str) -> bool:
4155
"""Checks if the filename has a user namespace.
@@ -87,15 +101,34 @@ async def save_artifact(
87101
session_id: Optional[str] = None,
88102
custom_metadata: Optional[dict[str, Any]] = None,
89103
) -> int:
90-
# TODO: b/447451270 - Support saving artifact with custom metadata.
91-
if custom_metadata:
92-
raise NotImplementedError("custom_metadata is not supported yet.")
93-
94104
path = self._artifact_path(app_name, user_id, filename, session_id)
95105
if path not in self.artifacts:
96106
self.artifacts[path] = []
97107
version = len(self.artifacts[path])
98-
self.artifacts[path].append(artifact)
108+
if self._file_has_user_namespace(filename):
109+
canonical_uri = f"memory://apps/{app_name}/users/{user_id}/artifacts/{filename}/versions/{version}"
110+
else:
111+
canonical_uri = f"memory://apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{filename}/versions/{version}"
112+
113+
artifact_version = ArtifactVersion(
114+
version=version,
115+
canonical_uri=canonical_uri,
116+
)
117+
if custom_metadata:
118+
artifact_version.custom_metadata = custom_metadata
119+
120+
if artifact.inline_data is not None:
121+
artifact_version.mime_type = artifact.inline_data.mime_type
122+
elif artifact.text is not None:
123+
artifact_version.mime_type = "text/plain"
124+
elif artifact.file_data is not None:
125+
artifact_version.mime_type = artifact.file_data.mime_type
126+
else:
127+
raise ValueError("Not supported artifact type.")
128+
129+
self.artifacts[path].append(
130+
_ArtifactEntry(data=artifact, artifact_version=artifact_version)
131+
)
99132
return version
100133

101134
@override
@@ -114,7 +147,10 @@ async def load_artifact(
114147
return None
115148
if version is None:
116149
version = -1
117-
return versions[version]
150+
try:
151+
return versions[version].data
152+
except IndexError:
153+
return None
118154

119155
@override
120156
async def list_artifact_keys(
@@ -172,8 +208,11 @@ async def list_artifact_versions(
172208
filename: str,
173209
session_id: Optional[str] = None,
174210
) -> list[ArtifactVersion]:
175-
# TODO: b/447451270 - Support list_artifact_versions.
176-
raise NotImplementedError("list_artifact_versions is not implemented yet.")
211+
path = self._artifact_path(app_name, user_id, filename, session_id)
212+
entries = self.artifacts.get(path)
213+
if not entries:
214+
return []
215+
return [entry.artifact_version for entry in entries]
177216

178217
@override
179218
async def get_artifact_version(
@@ -185,5 +224,14 @@ async def get_artifact_version(
185224
session_id: Optional[str] = None,
186225
version: Optional[int] = None,
187226
) -> Optional[ArtifactVersion]:
188-
# TODO: b/447451270 - Support get_artifact_version.
189-
raise NotImplementedError("get_artifact_version is not implemented yet.")
227+
path = self._artifact_path(app_name, user_id, filename, session_id)
228+
entries = self.artifacts.get(path)
229+
if not entries:
230+
return None
231+
232+
if version is None:
233+
version = -1
234+
try:
235+
return entries[version].artifact_version
236+
except IndexError:
237+
return None

tests/unittests/artifacts/test_artifact_service.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,24 @@
1414

1515
"""Tests for the artifact service."""
1616

17+
from datetime import datetime
1718
import enum
1819
from typing import Optional
1920
from typing import Union
2021
from unittest import mock
22+
from unittest.mock import patch
2123

24+
from google.adk.artifacts.base_artifact_service import ArtifactVersion
2225
from google.adk.artifacts.gcs_artifact_service import GcsArtifactService
2326
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
2427
from google.genai import types
2528
import pytest
2629

2730
Enum = enum.Enum
2831

32+
# Define a fixed datetime object to be returned by datetime.now()
33+
FIXED_DATETIME = datetime(2025, 1, 1, 12, 0, 0)
34+
2935

3036
class ArtifactServiceType(Enum):
3137
IN_MEMORY = "IN_MEMORY"
@@ -195,6 +201,15 @@ async def test_save_load_delete(service_type):
195201
== artifact
196202
)
197203

204+
# Attempt to load a version that doesn't exist
205+
assert not await artifact_service.load_artifact(
206+
app_name=app_name,
207+
user_id=user_id,
208+
session_id=session_id,
209+
filename=filename,
210+
version=3,
211+
)
212+
198213
await artifact_service.delete_artifact(
199214
app_name=app_name,
200215
user_id=user_id,
@@ -322,3 +337,171 @@ async def test_list_keys_preserves_user_prefix():
322337
# Should contain prefixed names and session file
323338
expected_keys = ["user:document.pdf", "user:image.png", "session_file.txt"]
324339
assert sorted(artifact_keys) == sorted(expected_keys)
340+
341+
342+
@pytest.mark.asyncio
343+
async def test_list_artifact_versions_and_get_artifact_version():
344+
"""Tests listing artifact versions and getting a specific version."""
345+
artifact_service = InMemoryArtifactService()
346+
app_name = "app0"
347+
user_id = "user0"
348+
session_id = "123"
349+
filename = "filename"
350+
versions = [
351+
types.Part.from_bytes(
352+
data=i.to_bytes(2, byteorder="big"), mime_type="text/plain"
353+
)
354+
for i in range(4)
355+
]
356+
357+
with patch(
358+
"google.adk.artifacts.base_artifact_service.datetime"
359+
) as mock_datetime:
360+
mock_datetime.now.return_value = FIXED_DATETIME
361+
362+
for i in range(4):
363+
await artifact_service.save_artifact(
364+
app_name=app_name,
365+
user_id=user_id,
366+
session_id=session_id,
367+
filename=filename,
368+
artifact=versions[i],
369+
custom_metadata={"key": "value" + str(i)},
370+
)
371+
372+
artifact_versions = await artifact_service.list_artifact_versions(
373+
app_name=app_name,
374+
user_id=user_id,
375+
session_id=session_id,
376+
filename=filename,
377+
)
378+
379+
expected_artifact_versions = [
380+
ArtifactVersion(
381+
version=i,
382+
canonical_uri=(
383+
f"memory://apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{filename}/versions/{i}"
384+
),
385+
custom_metadata={"key": "value" + str(i)},
386+
mime_type="text/plain",
387+
create_time=FIXED_DATETIME.timestamp(),
388+
)
389+
for i in range(4)
390+
]
391+
assert artifact_versions == expected_artifact_versions
392+
393+
# Get latest artifact version when version is not specified
394+
assert (
395+
await artifact_service.get_artifact_version(
396+
app_name=app_name,
397+
user_id=user_id,
398+
session_id=session_id,
399+
filename=filename,
400+
)
401+
== expected_artifact_versions[-1]
402+
)
403+
404+
# Get artifact version by version number
405+
assert (
406+
await artifact_service.get_artifact_version(
407+
app_name=app_name,
408+
user_id=user_id,
409+
session_id=session_id,
410+
filename=filename,
411+
version=2,
412+
)
413+
== expected_artifact_versions[2]
414+
)
415+
416+
417+
@pytest.mark.asyncio
418+
async def test_list_artifact_versions_with_user_prefix():
419+
"""Tests listing artifact versions with user prefix."""
420+
artifact_service = InMemoryArtifactService()
421+
app_name = "app0"
422+
user_id = "user0"
423+
session_id = "123"
424+
user_scoped_filename = "user:document.pdf"
425+
versions = [
426+
types.Part.from_bytes(
427+
data=i.to_bytes(2, byteorder="big"), mime_type="text/plain"
428+
)
429+
for i in range(4)
430+
]
431+
432+
with patch(
433+
"google.adk.artifacts.base_artifact_service.datetime"
434+
) as mock_datetime:
435+
mock_datetime.now.return_value = FIXED_DATETIME
436+
437+
for i in range(4):
438+
# Save artifacts with "user:" prefix (cross-session artifacts)
439+
await artifact_service.save_artifact(
440+
app_name=app_name,
441+
user_id=user_id,
442+
session_id=session_id,
443+
filename=user_scoped_filename,
444+
artifact=versions[i],
445+
custom_metadata={"key": "value" + str(i)},
446+
)
447+
448+
artifact_versions = await artifact_service.list_artifact_versions(
449+
app_name=app_name,
450+
user_id=user_id,
451+
session_id=session_id,
452+
filename=user_scoped_filename,
453+
)
454+
455+
expected_artifact_versions = [
456+
ArtifactVersion(
457+
version=i,
458+
canonical_uri=(
459+
f"memory://apps/{app_name}/users/{user_id}/artifacts/{user_scoped_filename}/versions/{i}"
460+
),
461+
custom_metadata={"key": "value" + str(i)},
462+
mime_type="text/plain",
463+
create_time=FIXED_DATETIME.timestamp(),
464+
)
465+
for i in range(4)
466+
]
467+
assert artifact_versions == expected_artifact_versions
468+
469+
470+
@pytest.mark.asyncio
471+
async def test_get_artifact_version_artifact_does_not_exist():
472+
"""Tests getting an artifact version when artifact does not exist."""
473+
artifact_service = InMemoryArtifactService()
474+
assert not await artifact_service.get_artifact_version(
475+
app_name="test_app",
476+
user_id="test_user",
477+
session_id="session_id",
478+
filename="filename",
479+
)
480+
481+
482+
@pytest.mark.asyncio
483+
async def test_get_artifact_version_out_of_index():
484+
"""Tests loading an artifact with an out-of-index version."""
485+
artifact_service = InMemoryArtifactService()
486+
app_name = "app0"
487+
user_id = "user0"
488+
session_id = "123"
489+
filename = "filename"
490+
artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain")
491+
492+
await artifact_service.save_artifact(
493+
app_name=app_name,
494+
user_id=user_id,
495+
session_id=session_id,
496+
filename=filename,
497+
artifact=artifact,
498+
)
499+
500+
# Attempt to get a version that doesn't exist
501+
assert not await artifact_service.get_artifact_version(
502+
app_name=app_name,
503+
user_id=user_id,
504+
session_id=session_id,
505+
filename=filename,
506+
version=3,
507+
)

tests/unittests/tools/test_agent_tool.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ def test_update_state():
178178
assert runner.session.state['state_1'] == 'changed_value'
179179

180180

181-
def test_update_artifacts():
181+
@mark.asyncio
182+
async def test_update_artifacts():
182183
"""The agent tool can read and write artifacts."""
183184

184185
async def before_tool_agent(callback_context: CallbackContext):
@@ -219,12 +220,21 @@ async def after_main_agent(callback_context: CallbackContext):
219220
runner = testing_utils.InMemoryRunner(root_agent)
220221
runner.run('test1')
221222

222-
artifacts_path = f'test_app/test_user/{runner.session_id}'
223-
assert runner.runner.artifact_service.artifacts == {
224-
f'{artifacts_path}/artifact_1': [Part.from_text(text='test')],
225-
f'{artifacts_path}/artifact_2': [Part.from_text(text='test 2')],
226-
f'{artifacts_path}/artifact_3': [Part.from_text(text='test 2 3')],
227-
}
223+
async def load_artifact(filename: str):
224+
return await runner.runner.artifact_service.load_artifact(
225+
app_name='test_app',
226+
user_id='test_user',
227+
session_id=runner.session_id,
228+
filename=filename,
229+
)
230+
231+
assert await runner.runner.artifact_service.list_artifact_keys(
232+
app_name='test_app', user_id='test_user', session_id=runner.session_id
233+
) == ['artifact_1', 'artifact_2', 'artifact_3']
234+
235+
assert await load_artifact('artifact_1') == Part.from_text(text='test')
236+
assert await load_artifact('artifact_2') == Part.from_text(text='test 2')
237+
assert await load_artifact('artifact_3') == Part.from_text(text='test 2 3')
228238

229239

230240
@mark.parametrize(

0 commit comments

Comments
 (0)