Skip to content

Commit

Permalink
Allow empty contents with count_tokens (#342)
Browse files Browse the repository at this point in the history
Change-Id: Ic20e2f88427d2e4fbc97847cf5c2df1f80a9a5a1
  • Loading branch information
MarkDaoust authored May 21, 2024
1 parent 88f7ab3 commit 05877f7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 10 deletions.
9 changes: 6 additions & 3 deletions google/generativeai/generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,6 @@ def _prepare_request(
tool_config: content_types.ToolConfigType | None,
) -> glm.GenerateContentRequest:
"""Creates a `glm.GenerateContentRequest` from raw inputs."""
if not contents:
raise TypeError("contents must not be empty")

tools_lib = self._get_tools_lib(tools)
if tools_lib is not None:
tools_lib = tools_lib.to_proto()
Expand Down Expand Up @@ -235,6 +232,9 @@ def generate_content(
tools: `glm.Tools` more info coming soon.
request_options: Options for the request.
"""
if not contents:
raise TypeError("contents must not be empty")

request = self._prepare_request(
contents=contents,
generation_config=generation_config,
Expand Down Expand Up @@ -282,6 +282,9 @@ async def generate_content_async(
request_options: helper_types.RequestOptionsType | None = None,
) -> generation_types.AsyncGenerateContentResponse:
"""The async version of `GenerativeModel.generate_content`."""
if not contents:
raise TypeError("contents must not be empty")

request = self._prepare_request(
contents=contents,
generation_config=generation_config,
Expand Down
33 changes: 26 additions & 7 deletions tests/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes()


def noop(x: int):
return x


def simple_part(text: str) -> glm.Content:
return glm.Content({"parts": [{"text": text}]})

Expand Down Expand Up @@ -725,18 +729,33 @@ def test_system_instruction(self, instruction, expected_instr):
self.assertEqual(req.system_instruction, expected_instr)

@parameterized.named_parameters(
["basic", "Hello"],
["list", ["Hello"]],
["basic", {"contents": "Hello"}],
["list", {"contents": ["Hello"]}],
[
"list2",
[{"text": "Hello"}, {"inline_data": {"data": b"PNG!", "mime_type": "image/png"}}],
{
"contents": [
{"text": "Hello"},
{"inline_data": {"data": b"PNG!", "mime_type": "image/png"}},
]
},
],
["contents", [{"role": "user", "parts": ["hello"]}]],
[
"contents",
{"contents": [{"role": "user", "parts": ["hello"]}]},
],
["empty", {}],
[
"system_instruction",
{"system_instruction": ["You are a cat"]},
],
["tools", {"tools": [noop]}],
)
def test_count_tokens_smoke(self, contents):
def test_count_tokens_smoke(self, kwargs):
si = kwargs.pop("system_instruction", None)
self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)]
model = generative_models.GenerativeModel("gemini-pro-vision")
response = model.count_tokens(contents)
model = generative_models.GenerativeModel("gemini-pro-vision", system_instruction=si)
response = model.count_tokens(**kwargs)
self.assertEqual(type(response).to_dict(response), {"total_tokens": 7})

@parameterized.named_parameters(
Expand Down

0 comments on commit 05877f7

Please sign in to comment.