Skip to content

Commit

Permalink
Explicit Caching patch (#377)
Browse files Browse the repository at this point in the history
* Squashed commit of the following:

commit acb3806
Author: Mayuresh Agashe <magashe@google.com>
Date:   Wed Jun 5 00:51:30 2024 +0000

    fix update method

    Change-Id: I433c25b2d80cdf6e483b59f61ff29bb8d2dc6595

commit fb9995c
Merge: 4627fe1 7b9758f
Author: Mark Daoust <markdaoust@google.com>
Date:   Tue Jun 4 09:55:38 2024 -0700

    Merge branch 'main' into caching

    Change-Id: I2bade6b0099f12dd37a24fe26cfda1981c58fbc0

commit 4627fe1
Author: Mark Daoust <markdaoust@google.com>
Date:   Tue Jun 4 09:54:31 2024 -0700

    use preview build

    Change-Id: Ic1cd4fc28f591794dc5fbff0647a00a77ea7f601

commit 8e86ef1
Author: Mayuresh Agashe <magashe@google.com>
Date:   Thu May 30 16:18:22 2024 +0000

    Refactor for genai.protos module

    Change-Id: I2f02d2421d7303f0309ec86f05d33c07332c03c1

commit 82d3c5a
Merge: bf6551a f08c789
Author: Mayuresh Agashe <magashe@google.com>
Date:   Thu May 30 15:57:27 2024 +0000

    Merge branch 'main' of https://github.com/mayureshagashe2105/generative-ai-python into caching

    Change-Id: Id2b259fe4b2c91653bf5e4d5e883f556366d8676

commit bf6551a
Author: Mayuresh Agashe <magashe@google.com>
Date:   Mon May 27 11:26:03 2024 +0000

    Fix types

    Change-Id: Id3e7316562f4029e5b7409ae725bb66e2207f075

commit 67472d3
Author: Mayuresh Agashe <magashe@google.com>
Date:   Mon May 27 11:26:03 2024 +0000

    Fix types

    Change-Id: Id3e7316562f4029e5b7409ae725bb66e2207f075

commit a1c8c72
Author: Mayuresh Agashe <magashe@google.com>
Date:   Mon May 27 11:15:15 2024 +0000

    Fix docstrings

    Change-Id: I6020df4e862a4f1d58462a4cd70876a8448293cf

commit f48cedc
Author: Mayuresh Agashe <magashe@google.com>
Date:   Mon May 27 11:13:44 2024 +0000

    Fix types

    Change-Id: Ia4bf6b936fab4c1992798c65cff91c15e51a92c0

commit 645ceab
Author: Mayuresh Agashe <magashe@google.com>
Date:   Mon May 27 05:54:26 2024 +0000

    blacken

    Change-Id: I4e073d821d29eea30801bdb7e2a8dc01bb7d6b9a

commit 17372e3
Author: Mayuresh Agashe <magashe@google.com>
Date:   Mon May 27 05:54:06 2024 +0000

    Add 'cached_content' to GenerativeModel's repr

    Change-Id: I06676fad23895e3e1a6393baa938fc1f2df57d80

commit d1fd749
Author: Mayuresh Agashe <magashe@google.com>
Date:   Mon May 27 05:04:43 2024 +0000

    Add type-annotations to __new__ to fix pytype checks

    Change-Id: I6c69c036e54d56d18ea60368fa0a1dcda2d315fd

commit f37df8c
Author: Mayuresh Agashe <magashe@google.com>
Date:   Sun May 26 06:51:54 2024 +0000

    mark name as OPTIONAL for CachedContent creation

    If not provided, the name will be randomly generated

    Change-Id: Ib95fbafd3dfe098b43164d7ee4d6c2a84b0aae2e

commit 59663c8
Author: Mayuresh Agashe <magashe@google.com>
Date:   Fri May 24 10:22:08 2024 +0000

    Add tests

    Change-Id: I249188fa585bd9b7193efa48b1cfca20b8a79821

commit e1d8c7a
Author: Mayuresh Agashe <magashe@google.com>
Date:   Fri May 24 10:21:42 2024 +0000

    Validate name checks for CachedContent creation

    Change-Id: Ie41602621d99ddff6404c6708c7278e0da790652

commit 2cde1a2
Author: Mayuresh Agashe <magashe@google.com>
Date:   Thu May 23 18:09:14 2024 +0000

    fix tests

    Change-Id: I39f61012f850a82e09a7afb80b527a0f99ad0ec7

commit d862dae
Author: Mayuresh Agashe <magashe@google.com>
Date:   Thu May 23 18:09:14 2024 +0000

    fix tests

    Change-Id: I39f61012f850a82e09a7afb80b527a0f99ad0ec7

commit d35cc71
Author: Mayuresh Agashe <magashe@google.com>
Date:   Thu May 23 23:12:38 2024 +0530

    Improve tests

commit e65d16e
Author: Mayuresh Agashe <magashe@google.com>
Date:   Thu May 23 23:12:05 2024 +0530

    blacken

commit cfc936e
Author: Mayuresh Agashe <magashe@google.com>
Date:   Thu May 23 23:10:16 2024 +0530

    Stroke out functional approach for CachedContent CURD ops

commit afd066d
Merge: 6fafe6b 0dca4ce
Author: Mayuresh Agashe <magashe@google.com>
Date:   Wed May 22 23:10:20 2024 +0530

    Merge branch 'main' into caching

commit 6fafe6b
Author: Mayuresh Agashe <magashe@google.com>
Date:   Wed May 22 10:49:35 2024 +0530

    rename get_cached_content to get

commit a4ac7a5
Merge: f13228d f987fde
Author: Mayuresh Agashe <magashe@google.com>
Date:   Tue May 21 23:32:41 2024 +0530

    Merge branch 'main' into caching

commit f13228d
Author: Mayuresh Agashe <magashe@google.com>
Date:   Fri Apr 26 16:54:09 2024 +0000

    *Inital prototype for explicit caching

    *Add basic CURD support for caching

    *Remove INPUT_ONLY marked fields from CachedContent dataclass

    *Rename files 'cached_content*' -> 'caching*'

    *Update 'Create' method for explicit instantination of 'CachedContent'

    *Add a factory method to instatinate model with `CachedContent` as
    its context

    *blacken

    *Add tests

    Change-Id: I694545243efda467d6fd599beded0dc6679b727d

Change-Id: I7b14d94f729953294780815f4c496888bb2ad46f

* Remove auto cache deletion

Change-Id: I4658e1c57f967faeb3945dffef0181a456d65370

* Rename _to_dict --> _get_update_fields

Change-Id: I3c92c65e8e5b215e98c1ac0eea6db033166dec78

* Fix tests

Change-Id: Id36d7606e13d15caf6870f29a108944c7f36eaeb

* Set 'CachedContent' as a public property

Remove __new__ construct

Change-Id: Ie4f5527270be90730341b6c3b67de71b9b6e9c5c

* blacken

Change-Id: I12498213a7fc2b257827ab0df87c6913e04cad25

* set 'role=user' when content is passed as a str (#4)

'to_content' method assigns a default 'role=user' to all the contents passed as a string

Change-Id: I748514a7839b7f1d36150b879c3d1464ca9e11ba

* Handle ttl and expire_time separately

Change-Id: If9c6f04fe8d419828e3efd2249f0698bca4d5bdc

* Remove name param

Change-Id: I40fe7c8fafdb014fb9c7e74956452aca9a666641

* Update caching_types.py

* Update caching.py

* Update docstrs and error messages

Change-Id: I111a1218a7d9783d494b84f0a11cb3b76c7ad9da

* Update model name to gemini-1.5-pro for caching tests

Change-Id: Ibb1f75c409afaac124ef70232be71e3a882f6015

* Remove dafault ttl assignment

Let the API set the dafault

Change-Id: Id8d125a085ed27229ddb78d5812ed5b5ad39227b

* blacken

Change-Id: I1d7fe0ec422589e237502b0eda687cf81ef21a21

* Remove client arg

Change-Id: I17f05a90a1514f404dd3527c0db1ce6147d2c47a

* Add 'usage_metadata' param to CachedContent class

Change-Id: Ic527c157bc2cd114948b73a8f1832c21dd61b52e

* Add 'display_name' to CachedContent class

Change-Id: Id0a9be9d1bfdb94dc9d5c4fc7af9dee89e5365a4

* update generativelanguage version, fix tests

Change-Id: I0acc57853ab7dde863bbbe4b30ae3957e6ec3d11

* format

Change-Id: Ib2e9a16aaa989021d3498f3e59f9983560919159

* fewer automatic 'role' insertions

Change-Id: I0752741532a451f8720fa5e110e68f0b4e66cc4b

* cleanup

Change-Id: I151a809f6d079b8e4b0ed30d1153a638c98cacfd

* Wrap the proto

Change-Id: I14b4c54652fb51b867fb43d4b3e9091e6eaccd4e

* Apply suggestions from code review

Co-authored-by: Mayuresh Agashe <magashe@google.com>

* fix

Change-Id: I381029fc8fc13c39e432b39084fc8feba305514e

* format

Change-Id: I8e0b44aebc102d3b2afb27a422c4d70d6c99d5d2

* cleanup

Change-Id: I024733b53cede5bfdf957ce7e56d6ad01fd4b2bf

* update version

Change-Id: Ic95dffb3e945e31adc0d98787942d27289512b8a

* fix

Change-Id: I6ffdabbddf0e803606b3638521ebfeb6796d2e4b

* typing

Change-Id: I629d4d111f0e640f4f4bf602ea33f70fdc9ca3e4

* Simplify update method

Accept kwargs instead of dict of updates and construct protos using kwargs

Change-Id: I7858d585b1aa6b965134e2fb90adff737172af92

* Add repr to CachedContent

Change-Id: Id4ec78ebf9d6e96f22f6bf37fc4509268fa552f4

* cleanup

Change-Id: I684b46f881735bceb3f9e09d8573721ddb29f98a

* blacken

Change-Id: I773e7a5b8a222c8b4435470cdc2b53be425d95e4

* Apply suggestions from code review

Change-Id: I2a12b9689001bbc41c460db5a9f0e87c77d4caf6

---------

Co-authored-by: Mark Daoust <markdaoust@google.com>
  • Loading branch information
mayureshagashe2105 and MarkDaoust authored Jun 13, 2024
1 parent dbd5498 commit 23b81d7
Show file tree
Hide file tree
Showing 6 changed files with 335 additions and 214 deletions.
266 changes: 160 additions & 106 deletions google/generativeai/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,90 +14,130 @@
# limitations under the License.
from __future__ import annotations

import dataclasses
import datetime
from typing import Any, Iterable, Optional
import textwrap
from typing import Iterable, Optional

from google.generativeai import protos
from google.generativeai.types.model_types import idecode_time
from google.generativeai.types import caching_types
from google.generativeai.types import content_types
from google.generativeai.utils import flatten_update_paths
from google.generativeai.client import get_default_cache_client

from google.protobuf import field_mask_pb2
import google.ai.generativelanguage as glm

_USER_ROLE = "user"
_MODEL_ROLE = "model"


@dataclasses.dataclass
class CachedContent:
"""Cached content resource."""

name: str
model: str
create_time: datetime.datetime
update_time: datetime.datetime
expire_time: datetime.datetime
def __init__(self, name):
"""Fetches a `CachedContent` resource.
# NOTE: Automatic CachedContent deletion using contextmanager is not P0(P1+).
# Adding basic support for now.
def __enter__(self):
return self
Identical to `CachedContent.get`.
def __exit__(self, exc_type, exc_value, exc_tb):
self.delete()

def _to_dict(self) -> protos.CachedContent:
proto_paths = {
"name": self.name,
"model": self.model,
}
return protos.CachedContent(**proto_paths)

def _apply_update(self, path, value):
parts = path.split(".")
for part in parts[:-1]:
self = getattr(self, part)
if parts[-1] == "ttl":
value = self.expire_time + datetime.timedelta(seconds=value["seconds"])
parts[-1] = "expire_time"
setattr(self, parts[-1], value)
Args:
name: The resource name referring to the cached content.
"""
client = get_default_cache_client()

@classmethod
def _decode_cached_content(cls, cached_content: protos.CachedContent) -> CachedContent:
# not supposed to get INPUT_ONLY repeated fields, but local gapic lib build
# is returning these, hence setting including_default_value_fields to False
cached_content = type(cached_content).to_dict(
cached_content, including_default_value_fields=False
if "cachedContents/" not in name:
name = "cachedContents/" + name

request = protos.GetCachedContentRequest(name=name)
response = client.get_cached_content(request)
self._proto = response

@property
def name(self) -> str:
return self._proto.name

@property
def model(self) -> str:
return self._proto.model

@property
def display_name(self) -> str:
return self._proto.display_name

@property
def usage_metadata(self) -> protos.CachedContent.UsageMetadata:
return self._proto.usage_metadata

@property
def create_time(self) -> datetime.datetime:
return self._proto.create_time

@property
def update_time(self) -> datetime.datetime:
return self._proto.update_time

@property
def expire_time(self) -> datetime.datetime:
return self._proto.expire_time

def __str__(self):
return textwrap.dedent(
f"""\
CachedContent(
name='{self.name}',
model='{self.model}',
display_name='{self.display_name}',
usage_metadata={'{'}
'total_token_count': {self.usage_metadata.total_token_count},
{'}'},
create_time={self.create_time},
update_time={self.update_time},
expire_time={self.expire_time}
)"""
)

idecode_time(cached_content, "create_time")
idecode_time(cached_content, "update_time")
# always decode `expire_time` as Timestamp is returned
# regardless of what was sent on input
idecode_time(cached_content, "expire_time")
return cls(**cached_content)
__repr__ = __str__

@classmethod
def _from_obj(cls, obj: CachedContent | protos.CachedContent | dict) -> CachedContent:
"""Creates an instance of CachedContent form an object, without calling `get`."""
self = cls.__new__(cls)
self._proto = protos.CachedContent()
self._update(obj)
return self

def _update(self, updates):
"""Updates this instance inplace, does not call the API's `update` method"""
if isinstance(updates, CachedContent):
updates = updates._proto

if not isinstance(updates, dict):
updates = type(updates).to_dict(updates, including_default_value_fields=False)

for key, value in updates.items():
setattr(self._proto, key, value)

@staticmethod
def _prepare_create_request(
model: str,
name: str | None = None,
*,
display_name: str | None = None,
system_instruction: Optional[content_types.ContentType] = None,
contents: Optional[content_types.ContentsType] = None,
tools: Optional[content_types.FunctionLibraryType] = None,
tool_config: Optional[content_types.ToolConfigType] = None,
ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
ttl: Optional[caching_types.TTLTypes] = None,
expire_time: Optional[caching_types.ExpireTimeTypes] = None,
) -> protos.CreateCachedContentRequest:
"""Prepares a CreateCachedContentRequest."""
if name is not None:
if not caching_types.valid_cached_content_name(name):
raise ValueError(caching_types.NAME_ERROR_MESSAGE.format(name=name))

name = "cachedContents/" + name
if ttl and expire_time:
raise ValueError(
"Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
)

if "/" not in model:
model = "models/" + model

if display_name and len(display_name) > 128:
raise ValueError("`display_name` must be no more than 128 unicode characters.")

if system_instruction:
system_instruction = content_types.to_content(system_instruction)

Expand All @@ -110,18 +150,21 @@ def _prepare_create_request(

if contents:
contents = content_types.to_contents(contents)
if not contents[-1].role:
contents[-1].role = _USER_ROLE

if ttl:
ttl = caching_types.to_ttl(ttl)
ttl = caching_types.to_optional_ttl(ttl)
expire_time = caching_types.to_optional_expire_time(expire_time)

cached_content = protos.CachedContent(
name=name,
model=model,
display_name=display_name,
system_instruction=system_instruction,
contents=contents,
tools=tools_lib,
tool_config=tool_config,
ttl=ttl,
expire_time=expire_time,
)

return protos.CreateCachedContentRequest(cached_content=cached_content)
Expand All @@ -130,48 +173,55 @@ def _prepare_create_request(
def create(
cls,
model: str,
name: str | None = None,
*,
display_name: str | None = None,
system_instruction: Optional[content_types.ContentType] = None,
contents: Optional[content_types.ContentsType] = None,
tools: Optional[content_types.FunctionLibraryType] = None,
tool_config: Optional[content_types.ToolConfigType] = None,
ttl: Optional[caching_types.ExpirationTypes] = datetime.timedelta(hours=1),
client: glm.CacheServiceClient | None = None,
ttl: Optional[caching_types.TTLTypes] = None,
expire_time: Optional[caching_types.ExpireTimeTypes] = None,
) -> CachedContent:
"""Creates `CachedContent` resource.
Args:
model: The name of the `model` to use for cached content creation.
Any `CachedContent` resource can be only used with the
`model` it was created for.
name: The resource name referring to the cached content.
display_name: The user-generated meaningful display name
of the cached content. `display_name` must be no
more than 128 unicode characters.
system_instruction: Developer set system instruction.
contents: Contents to cache.
tools: A list of `Tools` the model may use to generate response.
tool_config: Config to apply to all tools.
ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
`ttl` and `expire_time` are exclusive arguments.
expire_time: Expiration time for cached resource.
`ttl` and `expire_time` are exclusive arguments.
Returns:
`CachedContent` resource with specified name.
"""
if client is None:
client = get_default_cache_client()
client = get_default_cache_client()

request = cls._prepare_create_request(
model=model,
name=name,
display_name=display_name,
system_instruction=system_instruction,
contents=contents,
tools=tools,
tool_config=tool_config,
ttl=ttl,
expire_time=expire_time,
)

response = client.create_cached_content(request)
return cls._decode_cached_content(response)
result = CachedContent._from_obj(response)
return result

@classmethod
def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedContent:
def get(cls, name: str) -> CachedContent:
"""Fetches required `CachedContent` resource.
Args:
Expand All @@ -180,20 +230,18 @@ def get(cls, name: str, client: glm.CacheServiceClient | None = None) -> CachedC
Returns:
`CachedContent` resource with specified `name`.
"""
if client is None:
client = get_default_cache_client()
client = get_default_cache_client()

if "cachedContents/" not in name:
name = "cachedContents/" + name

request = protos.GetCachedContentRequest(name=name)
response = client.get_cached_content(request)
return cls._decode_cached_content(response)
result = CachedContent._from_obj(response)
return result

@classmethod
def list(
cls, page_size: Optional[int] = 1, client: glm.CacheServiceClient | None = None
) -> Iterable[CachedContent]:
def list(cls, page_size: Optional[int] = 1) -> Iterable[CachedContent]:
"""Lists `CachedContent` objects associated with the project.
Args:
Expand All @@ -203,58 +251,64 @@ def list(
Returns:
A paginated list of `CachedContent` objects.
"""
if client is None:
client = get_default_cache_client()
client = get_default_cache_client()

request = protos.ListCachedContentsRequest(page_size=page_size)
for cached_content in client.list_cached_contents(request):
yield cls._decode_cached_content(cached_content)
cached_content = CachedContent._from_obj(cached_content)
yield cached_content

def delete(self, client: glm.CachedServiceClient | None = None) -> None:
def delete(self) -> None:
"""Deletes `CachedContent` resource."""
if client is None:
client = get_default_cache_client()
client = get_default_cache_client()

request = protos.DeleteCachedContentRequest(name=self.name)
client.delete_cached_content(request)
return

def update(
self,
updates: dict[str, Any],
client: glm.CacheServiceClient | None = None,
) -> CachedContent:
*,
ttl: Optional[caching_types.TTLTypes] = None,
expire_time: Optional[caching_types.ExpireTimeTypes] = None,
) -> None:
"""Updates requested `CachedContent` resource.
Args:
updates: The list of fields to update. Currently only
`ttl/expire_time` is supported as an update path.
Returns:
`CachedContent` object with specified updates.
ttl: TTL for cached resource (in seconds). Defaults to 1 hour.
`ttl` and `expire_time` are exclusive arguments.
expire_time: Expiration time for cached resource.
`ttl` and `expire_time` are exclusive arguments.
"""
if client is None:
client = get_default_cache_client()

updates = flatten_update_paths(updates)
for update_path in updates:
if update_path == "ttl":
updates = updates.copy()
update_path_val = updates.get(update_path)
updates[update_path] = caching_types.to_ttl(update_path_val)
else:
raise ValueError(
f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead."
)
field_mask = field_mask_pb2.FieldMask()
client = get_default_cache_client()

for path in updates.keys():
field_mask.paths.append(path)
for path, value in updates.items():
self._apply_update(path, value)
if ttl and expire_time:
raise ValueError(
"Exclusive arguments: Please provide either `ttl` or `expire_time`, not both."
)

request = protos.UpdateCachedContentRequest(
cached_content=self._to_dict(), update_mask=field_mask
ttl = caching_types.to_optional_ttl(ttl)
expire_time = caching_types.to_optional_expire_time(expire_time)

updates = protos.CachedContent(
name=self.name,
ttl=ttl,
expire_time=expire_time,
)
client.update_cached_content(request)
return self

field_mask = field_mask_pb2.FieldMask()

if ttl:
field_mask.paths.append("ttl")
elif expire_time:
field_mask.paths.append("expire_time")
else:
raise ValueError(
f"Bad update name: Only `ttl` or `expire_time` can be updated for `CachedContent`."
)

request = protos.UpdateCachedContentRequest(cached_content=updates, update_mask=field_mask)
updated_cc = client.update_cached_content(request)
self._update(updated_cc)

return
Loading

0 comments on commit 23b81d7

Please sign in to comment.