From 6be874a7c6c43b7acbf5926e38e56d2ab367f5a1 Mon Sep 17 00:00:00 2001
From: Zhenyi Qi
Date: Mon, 17 Jun 2024 14:51:45 -0700
Subject: [PATCH] feat: GenAI - Context Caching - add get() classmethod and refresh() instance method

PiperOrigin-RevId: 644141561
---
 tests/unit/vertexai/test_caching.py | 45 +++++++++++++++++++++++++++++
 vertexai/caching/_caching.py        | 40 +++++++++++++------------
 2 files changed, 67 insertions(+), 18 deletions(-)

diff --git a/tests/unit/vertexai/test_caching.py b/tests/unit/vertexai/test_caching.py
index d3906865be..e98eb457b3 100644
--- a/tests/unit/vertexai/test_caching.py
+++ b/tests/unit/vertexai/test_caching.py
@@ -81,6 +81,32 @@ def get_cached_content(self, name, retry=None):
         yield get_cached_content
 
 
+@pytest.fixture
+def mock_list_cached_contents():
+    """Mocks GenAiCacheServiceClient.list_cached_contents()."""
+
+    def list_cached_contents(self, request):
+        del self, request
+        response = [
+            GapicCachedContent(
+                name="cached_content1_from_list_request",
+                model="model-name1",
+            ),
+            GapicCachedContent(
+                name="cached_content2_from_list_request",
+                model="model-name2",
+            ),
+        ]
+        return response
+
+    with mock.patch.object(
+        gen_ai_cache_service.client.GenAiCacheServiceClient,
+        "list_cached_contents",
+        new=list_cached_contents,
+    ) as list_cached_contents:
+        yield list_cached_contents
+
+
 @pytest.mark.usefixtures("google_auth_mock")
 class TestCaching:
     """Unit tests for caching.CachedContent."""
@@ -118,6 +144,19 @@ def test_constructor_with_only_content_id(self, mock_get_cached_content):
         )
         assert cache.model_name == "model-name"
 
+    def test_get_with_content_id(self, mock_get_cached_content):
+        partial_resource_name = "contents-id"
+
+        cache = caching.CachedContent.get(
+            cached_content_name=partial_resource_name,
+        )
+
+        assert cache.name == "contents-id"
+        assert cache.resource_name == (
+            f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/cachedContents/contents-id"
+        )
+        assert cache.model_name == "model-name"
+
     def test_create_with_real_payload(
         self, mock_create_cached_content, mock_get_cached_content
     ):
@@ -162,3 +201,9 @@ def test_create_with_real_payload_and_wrapped_type(
             == f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/model-name"
         )
         assert cache.name == _CREATED_CONTENT_ID
+
+    def test_list(self, mock_list_cached_contents):
+        cached_contents = caching.CachedContent.list()
+        for i, cached_content in enumerate(cached_contents):
+            assert cached_content.name == f"cached_content{i + 1}_from_list_request"
+            assert cached_content.model_name == f"model-name{i + 1}"
diff --git a/vertexai/caching/_caching.py b/vertexai/caching/_caching.py
index 00bab94fe8..ccf249fd5d 100644
--- a/vertexai/caching/_caching.py
+++ b/vertexai/caching/_caching.py
@@ -135,18 +135,7 @@ def __init__(self, cached_content_name: str):
                 "456".
         """
         super().__init__(resource_name=cached_content_name)
-
-        resource_name = aiplatform_utils.full_resource_name(
-            resource_name=cached_content_name,
-            resource_noun=self._resource_noun,
-            parse_resource_name_method=self._parse_resource_name,
-            format_resource_name_method=self._format_resource_name,
-            project=self.project,
-            location=self.location,
-            parent_resource_name_fields=None,
-            resource_id_validator=self._resource_id_validator,
-        )
-        self._gca_resource = gca_cached_content.CachedContent(name=resource_name)
+        self._gca_resource = self._get_gca_resource(cached_content_name)
 
     @property
     def _raw_cached_content(self) -> gca_cached_content.CachedContent:
@@ -154,9 +143,7 @@ def _raw_cached_content(self) -> gca_cached_content.CachedContent:
 
     @property
     def model_name(self) -> str:
-        if not self._raw_cached_content.model:
-            self._sync_gca_resource()
-        return self._raw_cached_content.model
+        return self._gca_resource.model
 
     @classmethod
     def create(
@@ -235,6 +222,10 @@ def create(
         obj._gca_resource = cached_content_resource
         return obj
 
+    def refresh(self):
+        """Syncs the local cached content with the remote resource."""
+        self._sync_gca_resource()
+
     def update(
         self,
         *,
@@ -265,15 +256,28 @@ def update(
 
     @property
     def expire_time(self) -> datetime.datetime:
-        """Time this resource was last updated."""
-        self._sync_gca_resource()
+        """Time this resource is considered expired.
+
+        The returned value may be stale. Use refresh() to get the latest value.
+
+        Returns:
+            The expiration time of the cached content resource.
+        """
         return self._gca_resource.expire_time
 
     def delete(self):
+        """Deletes the current cached content resource."""
         self._delete()
 
     @classmethod
-    def list(cls):
+    def list(cls) -> List["CachedContent"]:
+        """Lists the active cached content resources."""
         # TODO(b/345326114): Make list() interface richer after aligning with
         # Google AI SDK
         return cls._list()
+
+    @classmethod
+    def get(cls, cached_content_name: str) -> "CachedContent":
+        """Retrieves an existing cached content resource."""
+        cache = cls(cached_content_name)
+        return cache