From 2e62faebceaf496d9a511f930f23c579669af5a1 Mon Sep 17 00:00:00 2001 From: Logan Kilpatrick <23kilpatrick23@gmail.com> Date: Sun, 26 May 2024 21:34:06 -0500 Subject: [PATCH 1/2] Update __init__.py to use the latest model (#362) * Update __init__.py * Grammar --------- Co-authored-by: Mark McDonald --- google/generativeai/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/google/generativeai/__init__.py b/google/generativeai/__init__.py index 53383a1b3..57e848298 100644 --- a/google/generativeai/__init__.py +++ b/google/generativeai/__init__.py @@ -30,8 +30,8 @@ genai.configure(api_key=os.environ['API_KEY']) -model = genai.GenerativeModel(name='gemini-pro') -response = model.generate_content('Please summarise this document: ...') +model = genai.GenerativeModel(name='gemini-1.5-flash') +response = model.generate_content('Teach me about how an LLM works') print(response.text) ``` From f08c789741f30e49ecfb822540fd749920d62bcc Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 29 May 2024 20:46:12 -0700 Subject: [PATCH 2/2] Add genai.protos (#354) * Add genai.protos Change-Id: I21cfada033c6ffbed7a20e117e61582fde925f61 * Add genai.protos Change-Id: I9c8473d4ca1a0e92489f145a18ef1abd29af22b3 * test_protos.py Change-Id: I576080fb80cf9dc9345d8bb2178eb4b9ac59ce97 * fix docs + format Change-Id: I5f9aa3f8e3ae780e5cec2078d3eb153157b195fe * fix merge Change-Id: I17014791d966d797b481bca17df69558b23a9a1a * format Change-Id: I51d30f6568640456bcf28db2bd338a58a82346de * Fix client references Change-Id: I4899231706c9624a0f189b22b6f70aeeb4cbea29 * Fix tests Change-Id: I8a636fb634fd079a892cb99170a12c0613887ccf * add import Change-Id: I517171389801ef249cd478f98798181da83bef69 * fix import Change-Id: I8921c0caaa9b902ebde682ead31a2444298c2c9c * Update docstring Change-Id: I1f6b3b9b9521baa8812a908431bf58c623860733 * spelling Change-Id: I0421a35687ed14b1a5ca3b496cafd91514c4de92 * remove unused imports Change-Id: Ifc791796e36668eb473fd0fffea4833b1a062188 * Resolve review coments. Change-Id: Ieb900190f42e883337028ae25da3be819507db4a * Update docstring. 
Change-Id: I805473f9aaeb04e922a9f66bb5f40716d42fb738 * Fix typo --------- Co-authored-by: Mark McDonald --- docs/build_docs.py | 131 +------------ google/generativeai/__init__.py | 1 + google/generativeai/answer.py | 59 +++--- google/generativeai/client.py | 3 +- google/generativeai/discuss.py | 87 ++++----- google/generativeai/embedding.py | 15 +- google/generativeai/files.py | 8 +- google/generativeai/generative_models.py | 72 +++---- google/generativeai/models.py | 32 ++-- google/generativeai/operations.py | 12 +- google/generativeai/protos.py | 75 ++++++++ google/generativeai/responder.py | 77 ++++---- google/generativeai/retriever.py | 31 ++-- google/generativeai/text.py | 32 ++-- google/generativeai/types/answer_types.py | 4 +- google/generativeai/types/citation_types.py | 6 +- google/generativeai/types/content_types.py | 138 +++++++------- google/generativeai/types/discuss_types.py | 24 +-- google/generativeai/types/file_types.py | 24 +-- google/generativeai/types/generation_types.py | 60 +++--- google/generativeai/types/model_types.py | 42 ++--- .../generativeai/types/palm_safety_types.py | 134 +++++++------- google/generativeai/types/permission_types.py | 35 ++-- google/generativeai/types/retriever_types.py | 175 ++++++++++-------- google/generativeai/types/safety_types.py | 108 +++++------ tests/test_answer.py | 142 ++++++++------ tests/test_client.py | 4 +- tests/test_content.py | 146 +++++++-------- tests/test_discuss.py | 76 ++++---- tests/test_discuss_async.py | 22 +-- tests/test_embedding.py | 21 ++- tests/test_embedding_async.py | 21 ++- tests/test_files.py | 27 +-- tests/test_generation.py | 155 +++++++++------- tests/test_generative_models.py | 124 ++++++------- tests/test_generative_models_async.py | 34 ++-- tests/test_helpers.py | 12 +- tests/test_models.py | 120 ++++++------ tests/test_operations.py | 18 +- tests/test_permission.py | 70 +++---- tests/test_permission_async.py | 70 +++---- tests/test_protos.py | 34 ++++ tests/test_responder.py | 58 +++--- tests/test_retriever.py | 140 +++++++------- tests/test_retriever_async.py | 134 +++++++------- tests/test_safety.py | 14 +- tests/test_text.py | 96 +++++----- 47 files changed, 1499 insertions(+), 1424 deletions(-) create mode 100644 google/generativeai/protos.py create mode 100644 tests/test_protos.py diff --git a/docs/build_docs.py b/docs/build_docs.py index eaa6a1ba4..012cd3441 100644 --- a/docs/build_docs.py +++ b/docs/build_docs.py @@ -44,77 +44,13 @@ # For showing the conditional imports and types in `content_types.py` # grpc must be imported first. typing.TYPE_CHECKING = True -from google import generativeai as palm - +from google import generativeai as genai from tensorflow_docs.api_generator import generate_lib from tensorflow_docs.api_generator import public_api import yaml -glm.__doc__ = """\ -This package, `google.ai.generativelanguage`, is a low-level auto-generated client library for the PaLM API. - -```posix-terminal -pip install google.ai.generativelanguage -``` - -It is built using the same tooling as Google Cloud client libraries, and will be quite familiar if you've used -those before. - -While we encourage Python users to access the PaLM API using the `google.generativeai` package (aka `palm`), -this lower level package is also available. - -Each method in the PaLM API is connected to one of the client classes. 
Pass your API-key to the class' `client_options` -when initializing a client: - -``` -from google.ai import generativelanguage as glm - -client = glm.DiscussServiceClient( - client_options={'api_key':'YOUR_API_KEY'}) -``` - -To call the api, pass an appropriate request-proto-object. For the `DiscussServiceClient.generate_message` pass -a `generativelanguage.GenerateMessageRequest` instance: - -``` -request = glm.GenerateMessageRequest( - model='models/chat-bison-001', - prompt=glm.MessagePrompt( - messages=[glm.Message(content='Hello!')])) - -client.generate_message(request) -``` -``` -candidates { - author: "1" - content: "Hello! How can I help you today?" -} -... -``` - -For simplicity: - -* The API methods also accept key-word arguments. -* Anywhere you might pass a proto-object, the library will also accept simple python structures. - -So the following is equivalent to the previous example: - -``` -client.generate_message( - model='models/chat-bison-001', - prompt={'messages':[{'content':'Hello!'}]}) -``` -``` -candidates { - author: "1" - content: "Hello! How can I help you today?" -} -... -``` -""" - HERE = pathlib.Path(__file__).parent PROJECT_SHORT_NAME = "genai" @@ -139,43 +75,6 @@ ) -class MyFilter: - def __init__(self, base_dirs): - self.filter_base_dirs = public_api.FilterBaseDirs(base_dirs) - - def drop_staticmethods(self, parent, children): - parent = dict(parent.__dict__) - for name, value in children: - if not isinstance(parent.get(name, None), staticmethod): - yield name, value - - def __call__(self, path, parent, children): - if any("generativelanguage" in part for part in path) or "generativeai" in path: - children = self.filter_base_dirs(path, parent, children) - children = public_api.explicit_package_contents_filter(path, parent, children) - - if any("generativelanguage" in part for part in path): - if "ServiceClient" in path[-1] or "ServiceAsyncClient" in path[-1]: - children = list(self.drop_staticmethods(parent, children)) - - return children - - -class MyDocGenerator(generate_lib.DocGenerator): - def make_default_filters(self): - return [ - # filter the api. - public_api.FailIfNestedTooDeep(10), - public_api.filter_module_all, - public_api.add_proto_fields, - public_api.filter_private_symbols, - MyFilter(self._base_dir), # Replaces: public_api.FilterBaseDirs(self._base_dir), - public_api.FilterPrivateMap(self._private_map), - public_api.filter_doc_controls_skip, - public_api.ignore_typing, - ] - - def gen_api_docs(): """Generates api docs for the generative-ai package.""" for name in dir(google): @@ -188,11 +87,11 @@ def gen_api_docs(): """ ) - doc_generator = MyDocGenerator( + doc_generator = generate_lib.DocGenerator( root_title=PROJECT_FULL_NAME, - py_modules=[("google", google)], + py_modules=[("google.generativeai", genai)], base_dir=( - pathlib.Path(palm.__file__).parent, + pathlib.Path(genai.__file__).parent, pathlib.Path(glm.__file__).parent.parent, ), code_url_prefix=( @@ -201,32 +100,12 @@ def gen_api_docs(): ), search_hints=_SEARCH_HINTS.value, site_path=_SITE_PATH.value, - callbacks=[], + callbacks=[public_api.explicit_package_contents_filter], ) out_path = pathlib.Path(_OUTPUT_DIR.value) doc_generator.build(out_path) - # Fixup the toc file. 
- toc_path = out_path / "google/_toc.yaml" - toc = yaml.safe_load(toc_path.read_text()) - assert toc["toc"][0]["title"] == "google" - toc["toc"] = toc["toc"][1:] - toc["toc"][0]["title"] = "google.ai.generativelanguage" - toc["toc"][0]["section"] = toc["toc"][0]["section"][1]["section"] - toc["toc"][0], toc["toc"][1] = toc["toc"][1], toc["toc"][0] - toc_path.write_text(yaml.dump(toc)) - - # remove some dummy files and redirect them to `api/` - (out_path / "google.md").unlink() - (out_path / "google/ai.md").unlink() - redirects_path = out_path / "_redirects.yaml" - redirects = {"redirects": []} - redirects["redirects"].insert(0, {"from": "/api/python/google/ai", "to": "/api/"}) - redirects["redirects"].insert(0, {"from": "/api/python/google", "to": "/api/"}) - redirects["redirects"].insert(0, {"from": "/api/python", "to": "/api/"}) - redirects_path.write_text(yaml.dump(redirects)) - # clear `oneof` junk from proto pages for fpath in out_path.rglob("*.md"): old_content = fpath.read_text() diff --git a/google/generativeai/__init__.py b/google/generativeai/__init__.py index 57e848298..2b93fc1ce 100644 --- a/google/generativeai/__init__.py +++ b/google/generativeai/__init__.py @@ -42,6 +42,7 @@ from google.generativeai import version +from google.generativeai import protos from google.generativeai import types from google.generativeai.types import GenerationConfig diff --git a/google/generativeai/answer.py b/google/generativeai/answer.py index 4b9d9f97c..4bfabbf23 100644 --- a/google/generativeai/answer.py +++ b/google/generativeai/answer.py @@ -21,6 +21,7 @@ from typing_extensions import TypedDict import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.client import ( get_default_generative_client, @@ -35,7 +36,7 @@ DEFAULT_ANSWER_MODEL = "models/aqa" -AnswerStyle = glm.GenerateAnswerRequest.AnswerStyle +AnswerStyle = protos.GenerateAnswerRequest.AnswerStyle AnswerStyleOptions = Union[int, str, AnswerStyle] @@ -66,28 +67,30 @@ def to_answer_style(x: AnswerStyleOptions) -> AnswerStyle: GroundingPassageOptions = ( - Union[glm.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType], + Union[ + protos.GroundingPassage, tuple[str, content_types.ContentType], content_types.ContentType + ], ) GroundingPassagesOptions = Union[ - glm.GroundingPassages, + protos.GroundingPassages, Iterable[GroundingPassageOptions], Mapping[str, content_types.ContentType], ] -def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingPassages: +def _make_grounding_passages(source: GroundingPassagesOptions) -> protos.GroundingPassages: """ - Converts the `source` into a `glm.GroundingPassage`. A `GroundingPassages` contains a list of - `glm.GroundingPassage` objects, which each contain a `glm.Contant` and a string `id`. + Converts the `source` into a `protos.GroundingPassage`. A `GroundingPassages` contains a list of + `protos.GroundingPassage` objects, which each contain a `protos.Contant` and a string `id`. Args: - source: `Content` or a `GroundingPassagesOptions` that will be converted to glm.GroundingPassages. + source: `Content` or a `GroundingPassagesOptions` that will be converted to protos.GroundingPassages. Return: - `glm.GroundingPassages` to be passed into `glm.GenerateAnswer`. + `protos.GroundingPassages` to be passed into `protos.GenerateAnswer`. 
""" - if isinstance(source, glm.GroundingPassages): + if isinstance(source, protos.GroundingPassages): return source if not isinstance(source, Iterable): @@ -100,7 +103,7 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP source = source.items() for n, data in enumerate(source): - if isinstance(data, glm.GroundingPassage): + if isinstance(data, protos.GroundingPassage): passages.append(data) elif isinstance(data, tuple): id, content = data # tuple must have exactly 2 items. @@ -108,11 +111,11 @@ def _make_grounding_passages(source: GroundingPassagesOptions) -> glm.GroundingP else: passages.append({"id": str(n), "content": content_types.to_content(data)}) - return glm.GroundingPassages(passages=passages) + return protos.GroundingPassages(passages=passages) SourceNameType = Union[ - str, retriever_types.Corpus, glm.Corpus, retriever_types.Document, glm.Document + str, retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document ] @@ -127,7 +130,7 @@ class SemanticRetrieverConfigDict(TypedDict): SemanticRetrieverConfigOptions = Union[ SourceNameType, SemanticRetrieverConfigDict, - glm.SemanticRetrieverConfig, + protos.SemanticRetrieverConfig, ] @@ -135,7 +138,7 @@ def _maybe_get_source_name(source) -> str | None: if isinstance(source, str): return source elif isinstance( - source, (retriever_types.Corpus, glm.Corpus, retriever_types.Document, glm.Document) + source, (retriever_types.Corpus, protos.Corpus, retriever_types.Document, protos.Document) ): return source.name else: @@ -145,8 +148,8 @@ def _maybe_get_source_name(source) -> str | None: def _make_semantic_retriever_config( source: SemanticRetrieverConfigOptions, query: content_types.ContentsType, -) -> glm.SemanticRetrieverConfig: - if isinstance(source, glm.SemanticRetrieverConfig): +) -> protos.SemanticRetrieverConfig: + if isinstance(source, protos.SemanticRetrieverConfig): return source name = _maybe_get_source_name(source) @@ -156,7 +159,7 @@ def _make_semantic_retriever_config( source["source"] = _maybe_get_source_name(source["source"]) else: raise TypeError( - f"Invalid input: Failed to create a 'glm.SemanticRetrieverConfig' from the provided source. " + f"Invalid input: Failed to create a 'protos.SemanticRetrieverConfig' from the provided source. " f"Received type: {type(source).__name__}, " f"Received value: {source}" ) @@ -166,7 +169,7 @@ def _make_semantic_retriever_config( elif isinstance(source["query"], str): source["query"] = content_types.to_content(source["query"]) - return glm.SemanticRetrieverConfig(source) + return protos.SemanticRetrieverConfig(source) def _make_generate_answer_request( @@ -178,9 +181,9 @@ def _make_generate_answer_request( answer_style: AnswerStyle | None = None, safety_settings: safety_types.SafetySettingOptions | None = None, temperature: float | None = None, -) -> glm.GenerateAnswerRequest: +) -> protos.GenerateAnswerRequest: """ - constructs a glm.GenerateAnswerRequest object by organizing the input parameters for the API call to generate a grounded answer from the model. + constructs a protos.GenerateAnswerRequest object by organizing the input parameters for the API call to generate a grounded answer from the model. Args: model: Name of the model used to generate the grounded response. @@ -188,16 +191,16 @@ def _make_generate_answer_request( single question to answer. For multi-turn queries, this is a repeated field that contains conversation history and the last `Content` in the list containing the question. 
inline_passages: Grounding passages (a list of `Content`-like objects or `(id, content)` pairs, - or a `glm.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, one must be set, but not both. - semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding. Exclusive with + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with `inline_passages`, one must be set, but not both. answer_style: Style for grounded answers. safety_settings: Safety settings for generated output. temperature: The temperature for randomness in the output. Returns: - Call for glm.GenerateAnswerRequest(). + Call for protos.GenerateAnswerRequest(). """ model = model_types.make_model_name(model) @@ -224,7 +227,7 @@ def _make_generate_answer_request( if answer_style: answer_style = to_answer_style(answer_style) - return glm.GenerateAnswerRequest( + return protos.GenerateAnswerRequest( model=model, contents=contents, inline_passages=inline_passages, @@ -273,9 +276,9 @@ def generate_answer( contents: The question to be answered by the model, grounded in the provided source. inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs, - or a `glm.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, one must be set, but not both. - semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding. Exclusive with + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with `inline_passages`, one must be set, but not both. answer_style: Style in which the grounded answer should be returned. safety_settings: Safety settings for generated output. Defaults to None. @@ -327,9 +330,9 @@ async def generate_answer_async( contents: The question to be answered by the model, grounded in the provided source. inline_passages: Grounding passages (a list of `Content`-like objects or (id, content) pairs, - or a `glm.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, + or a `protos.GroundingPassages`) to send inline with the request. Exclusive with `semantic_retreiver`, one must be set, but not both. - semantic_retriever: A Corpus, Document, or `glm.SemanticRetrieverConfig` to use for grounding. Exclusive with + semantic_retriever: A Corpus, Document, or `protos.SemanticRetrieverConfig` to use for grounding. Exclusive with `inline_passages`, one must be set, but not both. answer_style: Style in which the grounded answer should be returned. safety_settings: Safety settings for generated output. Defaults to None. 
diff --git a/google/generativeai/client.py b/google/generativeai/client.py index d969889d0..40c2bdcaf 100644 --- a/google/generativeai/client.py +++ b/google/generativeai/client.py @@ -10,6 +10,7 @@ import httplib2 import google.ai.generativelanguage as glm +import google.generativeai.protos as protos from google.auth import credentials as ga_credentials from google.auth import exceptions as ga_exceptions @@ -76,7 +77,7 @@ def create_file( name: str | None = None, display_name: str | None = None, resumable: bool = True, - ) -> glm.File: + ) -> protos.File: if self._discovery_api is None: self._setup_discovery_api() diff --git a/google/generativeai/discuss.py b/google/generativeai/discuss.py index b084ccad8..448347b41 100644 --- a/google/generativeai/discuss.py +++ b/google/generativeai/discuss.py @@ -18,37 +18,38 @@ import sys import textwrap -from typing import Any, Iterable, List, Optional, Union +from typing import Iterable, List import google.ai.generativelanguage as glm from google.generativeai.client import get_default_discuss_client from google.generativeai.client import get_default_discuss_async_client from google.generativeai import string_utils +from google.generativeai import protos from google.generativeai.types import discuss_types from google.generativeai.types import helper_types from google.generativeai.types import model_types from google.generativeai.types import palm_safety_types -def _make_message(content: discuss_types.MessageOptions) -> glm.Message: - """Creates a `glm.Message` object from the provided content.""" - if isinstance(content, glm.Message): +def _make_message(content: discuss_types.MessageOptions) -> protos.Message: + """Creates a `protos.Message` object from the provided content.""" + if isinstance(content, protos.Message): return content if isinstance(content, str): - return glm.Message(content=content) + return protos.Message(content=content) else: - return glm.Message(content) + return protos.Message(content) def _make_messages( messages: discuss_types.MessagesOptions, -) -> List[glm.Message]: +) -> List[protos.Message]: """ - Creates a list of `glm.Message` objects from the provided messages. + Creates a list of `protos.Message` objects from the provided messages. This function takes a variety of message content inputs, such as strings, dictionaries, - or `glm.Message` objects, and creates a list of `glm.Message` objects. It ensures that + or `protos.Message` objects, and creates a list of `protos.Message` objects. It ensures that the authors of the messages alternate appropriately. If authors are not provided, default authors are assigned based on their position in the list. @@ -56,9 +57,9 @@ def _make_messages( messages: The messages to convert. Returns: - A list of `glm.Message` objects with alternating authors. + A list of `protos.Message` objects with alternating authors. 
""" - if isinstance(messages, (str, dict, glm.Message)): + if isinstance(messages, (str, dict, protos.Message)): messages = [_make_message(messages)] else: messages = [_make_message(message) for message in messages] @@ -93,39 +94,39 @@ def _make_messages( return messages -def _make_example(item: discuss_types.ExampleOptions) -> glm.Example: - """Creates a `glm.Example` object from the provided item.""" - if isinstance(item, glm.Example): +def _make_example(item: discuss_types.ExampleOptions) -> protos.Example: + """Creates a `protos.Example` object from the provided item.""" + if isinstance(item, protos.Example): return item if isinstance(item, dict): item = item.copy() item["input"] = _make_message(item["input"]) item["output"] = _make_message(item["output"]) - return glm.Example(item) + return protos.Example(item) if isinstance(item, Iterable): input, output = list(item) - return glm.Example(input=_make_message(input), output=_make_message(output)) + return protos.Example(input=_make_message(input), output=_make_message(output)) # try anyway - return glm.Example(item) + return protos.Example(item) def _make_examples_from_flat( examples: List[discuss_types.MessageOptions], -) -> List[glm.Example]: +) -> List[protos.Example]: """ - Creates a list of `glm.Example` objects from a list of message options. + Creates a list of `protos.Example` objects from a list of message options. This function takes a list of `discuss_types.MessageOptions` and pairs them into - `glm.Example` objects. The input examples must be in pairs to create valid examples. + `protos.Example` objects. The input examples must be in pairs to create valid examples. Args: examples: The list of `discuss_types.MessageOptions`. Returns: - A list of `glm.Example objects` created by pairing up the provided messages. + A list of `protos.Example objects` created by pairing up the provided messages. Raises: ValueError: If the provided list of examples is not of even length. @@ -145,7 +146,7 @@ def _make_examples_from_flat( pair.append(msg) if n % 2 == 0: continue - primer = glm.Example( + primer = protos.Example( input=pair[0], output=pair[1], ) @@ -156,21 +157,21 @@ def _make_examples_from_flat( def _make_examples( examples: discuss_types.ExamplesOptions, -) -> List[glm.Example]: +) -> List[protos.Example]: """ - Creates a list of `glm.Example` objects from the provided examples. + Creates a list of `protos.Example` objects from the provided examples. This function takes various types of example content inputs and creates a list - of `glm.Example` objects. It handles the conversion of different input types and ensures + of `protos.Example` objects. It handles the conversion of different input types and ensures the appropriate structure for creating valid examples. Args: examples: The examples to convert. Returns: - A list of `glm.Example` objects created from the provided examples. + A list of `protos.Example` objects created from the provided examples. """ - if isinstance(examples, glm.Example): + if isinstance(examples, protos.Example): return [examples] if isinstance(examples, dict): @@ -208,11 +209,11 @@ def _make_message_prompt_dict( context: str | None = None, examples: discuss_types.ExamplesOptions | None = None, messages: discuss_types.MessagesOptions | None = None, -) -> glm.MessagePrompt: +) -> protos.MessagePrompt: """ - Creates a `glm.MessagePrompt` object from the provided prompt components. + Creates a `protos.MessagePrompt` object from the provided prompt components. 
- This function constructs a `glm.MessagePrompt` object using the provided `context`, `examples`, + This function constructs a `protos.MessagePrompt` object using the provided `context`, `examples`, or `messages`. It ensures the proper structure and handling of the input components. Either pass a `prompt` or it's component `context`, `examples`, `messages`. @@ -224,7 +225,7 @@ def _make_message_prompt_dict( messages: The messages for the prompt. Returns: - A `glm.MessagePrompt` object created from the provided prompt components. + A `protos.MessagePrompt` object created from the provided prompt components. """ if prompt is None: prompt = dict( @@ -238,7 +239,7 @@ def _make_message_prompt_dict( raise ValueError( "Invalid configuration: Either `prompt` or its fields `(context, examples, messages)` should be set, but not both simultaneously." ) - if isinstance(prompt, glm.MessagePrompt): + if isinstance(prompt, protos.MessagePrompt): return prompt elif isinstance(prompt, dict): # Always check dict before Iterable. pass @@ -268,12 +269,12 @@ def _make_message_prompt( context: str | None = None, examples: discuss_types.ExamplesOptions | None = None, messages: discuss_types.MessagesOptions | None = None, -) -> glm.MessagePrompt: - """Creates a `glm.MessagePrompt` object from the provided prompt components.""" +) -> protos.MessagePrompt: + """Creates a `protos.MessagePrompt` object from the provided prompt components.""" prompt = _make_message_prompt_dict( prompt=prompt, context=context, examples=examples, messages=messages ) - return glm.MessagePrompt(prompt) + return protos.MessagePrompt(prompt) def _make_generate_message_request( @@ -287,15 +288,15 @@ def _make_generate_message_request( top_p: float | None = None, top_k: float | None = None, prompt: discuss_types.MessagePromptOptions | None = None, -) -> glm.GenerateMessageRequest: - """Creates a `glm.GenerateMessageRequest` object for generating messages.""" +) -> protos.GenerateMessageRequest: + """Creates a `protos.GenerateMessageRequest` object for generating messages.""" model = model_types.make_model_name(model) prompt = _make_message_prompt( prompt=prompt, context=context, examples=examples, messages=messages ) - return glm.GenerateMessageRequest( + return protos.GenerateMessageRequest( model=model, prompt=prompt, temperature=temperature, @@ -514,9 +515,9 @@ async def reply_async( def _build_chat_response( - request: glm.GenerateMessageRequest, - response: glm.GenerateMessageResponse, - client: glm.DiscussServiceClient | glm.DiscussServiceAsyncClient, + request: protos.GenerateMessageRequest, + response: protos.GenerateMessageResponse, + client: glm.DiscussServiceClient | protos.DiscussServiceAsyncClient, ) -> ChatResponse: request = type(request).to_dict(request) prompt = request.pop("prompt") @@ -541,7 +542,7 @@ def _build_chat_response( def _generate_response( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, client: glm.DiscussServiceClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: @@ -557,7 +558,7 @@ def _generate_response( async def _generate_response_async( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, client: glm.DiscussServiceAsyncClient | None = None, request_options: helper_types.RequestOptionsType | None = None, ) -> ChatResponse: diff --git a/google/generativeai/embedding.py b/google/generativeai/embedding.py index 8218ec11d..616fa07bf 100644 --- a/google/generativeai/embedding.py +++ 
b/google/generativeai/embedding.py @@ -18,6 +18,7 @@ from typing import Any, Iterable, overload, TypeVar, Union, Mapping import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.client import get_default_generative_client from google.generativeai.client import get_default_generative_async_client @@ -30,7 +31,7 @@ DEFAULT_EMB_MODEL = "models/embedding-001" EMBEDDING_MAX_BATCH_SIZE = 100 -EmbeddingTaskType = glm.TaskType +EmbeddingTaskType = protos.TaskType EmbeddingTaskTypeOptions = Union[int, str, EmbeddingTaskType] @@ -183,7 +184,7 @@ def embed_content( if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)): result = {"embedding": []} requests = ( - glm.EmbedContentRequest( + protos.EmbedContentRequest( model=model, content=content_types.to_content(c), task_type=task_type, @@ -193,7 +194,7 @@ def embed_content( for c in content ) for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE): - embedding_request = glm.BatchEmbedContentsRequest(model=model, requests=batch) + embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch) embedding_response = client.batch_embed_contents( embedding_request, **request_options, @@ -202,7 +203,7 @@ def embed_content( result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"]) return result else: - embedding_request = glm.EmbedContentRequest( + embedding_request = protos.EmbedContentRequest( model=model, content=content_types.to_content(content), task_type=task_type, @@ -276,7 +277,7 @@ async def embed_content_async( if isinstance(content, Iterable) and not isinstance(content, (str, Mapping)): result = {"embedding": []} requests = ( - glm.EmbedContentRequest( + protos.EmbedContentRequest( model=model, content=content_types.to_content(c), task_type=task_type, @@ -286,7 +287,7 @@ async def embed_content_async( for c in content ) for batch in _batched(requests, EMBEDDING_MAX_BATCH_SIZE): - embedding_request = glm.BatchEmbedContentsRequest(model=model, requests=batch) + embedding_request = protos.BatchEmbedContentsRequest(model=model, requests=batch) embedding_response = await client.batch_embed_contents( embedding_request, **request_options, @@ -295,7 +296,7 @@ async def embed_content_async( result["embedding"].extend(e["values"] for e in embedding_dict["embeddings"]) return result else: - embedding_request = glm.EmbedContentRequest( + embedding_request = protos.EmbedContentRequest( model=model, content=content_types.to_content(content), task_type=task_type, diff --git a/google/generativeai/files.py b/google/generativeai/files.py index 386592225..4028d37f7 100644 --- a/google/generativeai/files.py +++ b/google/generativeai/files.py @@ -19,7 +19,7 @@ import mimetypes from typing import Iterable import logging -import google.ai.generativelanguage as glm +from google.generativeai import protos from itertools import islice from google.generativeai.types import file_types @@ -76,7 +76,7 @@ def list_files(page_size=100) -> Iterable[file_types.File]: """Calls the API to list files using a supported file service.""" client = get_default_file_client() - response = client.list_files(glm.ListFilesRequest(page_size=page_size)) + response = client.list_files(protos.ListFilesRequest(page_size=page_size)) for proto in response: yield file_types.File(proto) @@ -89,8 +89,8 @@ def get_file(name) -> file_types.File: def delete_file(name): """Calls the API to permanently delete a specified file using a supported file service.""" - if isinstance(name, 
(file_types.File, glm.File)): + if isinstance(name, (file_types.File, protos.File)): name = name.name - request = glm.DeleteFileRequest(name=name) + request = protos.DeleteFileRequest(name=name) client = get_default_file_client() client.delete_file(request=request) diff --git a/google/generativeai/generative_models.py b/google/generativeai/generative_models.py index 873d2fcb4..7d69ae8f9 100644 --- a/google/generativeai/generative_models.py +++ b/google/generativeai/generative_models.py @@ -11,7 +11,7 @@ import google.api_core.exceptions -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import client from google.generativeai.types import content_types from google.generativeai.types import generation_types @@ -125,8 +125,8 @@ def _prepare_request( safety_settings: safety_types.SafetySettingOptions | None = None, tools: content_types.FunctionLibraryType | None, tool_config: content_types.ToolConfigType | None, - ) -> glm.GenerateContentRequest: - """Creates a `glm.GenerateContentRequest` from raw inputs.""" + ) -> protos.GenerateContentRequest: + """Creates a `protos.GenerateContentRequest` from raw inputs.""" tools_lib = self._get_tools_lib(tools) if tools_lib is not None: tools_lib = tools_lib.to_proto() @@ -147,7 +147,7 @@ def _prepare_request( merged_ss.update(safety_settings) merged_ss = safety_types.normalize_safety_settings(merged_ss) - return glm.GenerateContentRequest( + return protos.GenerateContentRequest( model=self._model_name, contents=contents, generation_config=merged_gc, @@ -209,25 +209,25 @@ def generate_content( ### Input type flexibility - While the underlying API strictly expects a `list[glm.Content]` objects, this method + While the underlying API strictly expects a `list[protos.Content]` objects, this method will convert the user input into the correct type. The hierarchy of types that can be converted is below. Any of these objects can be passed as an equivalent `dict`. - * `Iterable[glm.Content]` - * `glm.Content` - * `Iterable[glm.Part]` - * `glm.Part` - * `str`, `Image`, or `glm.Blob` + * `Iterable[protos.Content]` + * `protos.Content` + * `Iterable[protos.Part]` + * `protos.Part` + * `str`, `Image`, or `protos.Blob` - In an `Iterable[glm.Content]` each `content` is a separate message. - But note that an `Iterable[glm.Part]` is taken as the parts of a single message. + In an `Iterable[protos.Content]` each `content` is a separate message. + But note that an `Iterable[protos.Part]` is taken as the parts of a single message. Arguments: contents: The contents serving as the model's prompt. generation_config: Overrides for the model's generation config. safety_settings: Overrides for the model's safety settings. stream: If True, yield response chunks as they are generated. - tools: `glm.Tools` more info coming soon. + tools: `protos.Tools` more info coming soon. request_options: Options for the request. 
""" if not contents: @@ -328,14 +328,14 @@ def count_tokens( tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, request_options: helper_types.RequestOptionsType | None = None, - ) -> glm.CountTokensResponse: + ) -> protos.CountTokensResponse: if request_options is None: request_options = {} if self._client is None: self._client = client.get_default_generative_client() - request = glm.CountTokensRequest( + request = protos.CountTokensRequest( model=self.model_name, generate_content_request=self._prepare_request( contents=contents, @@ -355,14 +355,14 @@ async def count_tokens_async( tools: content_types.FunctionLibraryType | None = None, tool_config: content_types.ToolConfigType | None = None, request_options: helper_types.RequestOptionsType | None = None, - ) -> glm.CountTokensResponse: + ) -> protos.CountTokensResponse: if request_options is None: request_options = {} if self._async_client is None: self._async_client = client.get_default_generative_async_client() - request = glm.CountTokensRequest( + request = protos.CountTokensRequest( model=self.model_name, generate_content_request=self._prepare_request( contents=contents, @@ -388,7 +388,7 @@ def start_chat( >>> response = chat.send_message("Hello?") Arguments: - history: An iterable of `glm.Content` objects, or equivalents to initialize the session. + history: An iterable of `protos.Content` objects, or equivalents to initialize the session. """ if self._generation_config.get("candidate_count", 1) > 1: raise ValueError( @@ -430,8 +430,8 @@ def __init__( enable_automatic_function_calling: bool = False, ): self.model: GenerativeModel = model - self._history: list[glm.Content] = content_types.to_contents(history) - self._last_sent: glm.Content | None = None + self._history: list[protos.Content] = content_types.to_contents(history) + self._last_sent: protos.Content | None = None self._last_received: generation_types.BaseGenerateContentResponse | None = None self.enable_automatic_function_calling = enable_automatic_function_calling @@ -535,13 +535,13 @@ def _check_response(self, *, response, stream): if not stream: if response.candidates[0].finish_reason not in ( - glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, - glm.Candidate.FinishReason.STOP, - glm.Candidate.FinishReason.MAX_TOKENS, + protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, + protos.Candidate.FinishReason.STOP, + protos.Candidate.FinishReason.MAX_TOKENS, ): raise generation_types.StopCandidateException(response.candidates[0]) - def _get_function_calls(self, response) -> list[glm.FunctionCall]: + def _get_function_calls(self, response) -> list[protos.FunctionCall]: candidates = response.candidates if len(candidates) != 1: raise ValueError( @@ -561,14 +561,14 @@ def _handle_afc( stream, tools_lib, request_options, - ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]: + ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]: while function_calls := self._get_function_calls(response): if not all(callable(tools_lib[fc]) for fc in function_calls): break history.append(response.candidates[0].content) - function_response_parts: list[glm.Part] = [] + function_response_parts: list[protos.Part] = [] for fc in function_calls: fr = tools_lib(fc) assert fr is not None, ( @@ -577,7 +577,7 @@ def _handle_afc( ) function_response_parts.append(fr) - send = glm.Content(role=self._USER_ROLE, parts=function_response_parts) + send = 
protos.Content(role=self._USER_ROLE, parts=function_response_parts) history.append(send) response = self.model.generate_content( @@ -668,14 +668,14 @@ async def _handle_afc_async( stream, tools_lib, request_options, - ) -> tuple[list[glm.Content], glm.Content, generation_types.BaseGenerateContentResponse]: + ) -> tuple[list[protos.Content], protos.Content, generation_types.BaseGenerateContentResponse]: while function_calls := self._get_function_calls(response): if not all(callable(tools_lib[fc]) for fc in function_calls): break history.append(response.candidates[0].content) - function_response_parts: list[glm.Part] = [] + function_response_parts: list[protos.Part] = [] for fc in function_calls: fr = tools_lib(fc) assert fr is not None, ( @@ -684,7 +684,7 @@ async def _handle_afc_async( ) function_response_parts.append(fr) - send = glm.Content(role=self._USER_ROLE, parts=function_response_parts) + send = protos.Content(role=self._USER_ROLE, parts=function_response_parts) history.append(send) response = await self.model.generate_content_async( @@ -708,7 +708,7 @@ def __copy__(self): history=list(self.history), ) - def rewind(self) -> tuple[glm.Content, glm.Content]: + def rewind(self) -> tuple[protos.Content, protos.Content]: """Removes the last request/response pair from the chat history.""" if self._last_received is None: result = self._history.pop(-2), self._history.pop() @@ -725,16 +725,16 @@ def last(self) -> generation_types.BaseGenerateContentResponse | None: return self._last_received @property - def history(self) -> list[glm.Content]: + def history(self) -> list[protos.Content]: """The chat history.""" last = self._last_received if last is None: return self._history if last.candidates[0].finish_reason not in ( - glm.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, - glm.Candidate.FinishReason.STOP, - glm.Candidate.FinishReason.MAX_TOKENS, + protos.Candidate.FinishReason.FINISH_REASON_UNSPECIFIED, + protos.Candidate.FinishReason.STOP, + protos.Candidate.FinishReason.MAX_TOKENS, ): error = generation_types.StopCandidateException(last.candidates[0]) last._error = error @@ -770,7 +770,7 @@ def __repr__(self) -> str: _model = str(self.model).replace("\n", "\n" + " " * 4) def content_repr(x): - return f"glm.Content({_dict_repr.repr(type(x).to_dict(x))})" + return f"protos.Content({_dict_repr.repr(type(x).to_dict(x))})" try: history = list(self.history) diff --git a/google/generativeai/models.py b/google/generativeai/models.py index 1f9e836e7..9ba0745c1 100644 --- a/google/generativeai/models.py +++ b/google/generativeai/models.py @@ -18,6 +18,8 @@ from typing import Any, Literal import google.ai.generativelanguage as glm + +from google.generativeai import protos from google.generativeai import operations from google.generativeai.client import get_default_model_client from google.generativeai.types import model_types @@ -155,16 +157,16 @@ def get_base_model_name( base_model = model.base_model elif isinstance(model, model_types.Model): base_model = model.name - elif isinstance(model, glm.Model): + elif isinstance(model, protos.Model): base_model = model.name - elif isinstance(model, glm.TunedModel): + elif isinstance(model, protos.TunedModel): base_model = getattr(model, "base_model", None) if not base_model: base_model = model.tuned_model_source.base_model else: raise TypeError( f"Invalid model: The provided model '{model}' is not recognized or supported. " - "Supported types are: str, model_types.TunedModel, model_types.Model, glm.Model, and glm.TunedModel." 
+ "Supported types are: str, model_types.TunedModel, model_types.Model, protos.Model, and protos.TunedModel." ) return base_model @@ -282,9 +284,9 @@ def create_tuned_model( Args: source_model: The name of the model to tune. training_data: The dataset to tune the model on. This must be either: - * A `glm.Dataset`, or + * A `protos.Dataset`, or * An `Iterable` of: - *`glm.TuningExample`, + *`protos.TuningExample`, * `{'text_input': text_input, 'output': output}` dicts * `(text_input, output)` tuples. * A `Mapping` of `Iterable[str]` - use `input_key` and `output_key` to choose which @@ -339,17 +341,17 @@ def create_tuned_model( training_data, input_key=input_key, output_key=output_key ) - hyperparameters = glm.Hyperparameters( + hyperparameters = protos.Hyperparameters( epoch_count=epoch_count, batch_size=batch_size, learning_rate=learning_rate, ) - tuning_task = glm.TuningTask( + tuning_task = protos.TuningTask( training_data=training_data, hyperparameters=hyperparameters, ) - tuned_model = glm.TunedModel( + tuned_model = protos.TunedModel( **source_model, display_name=display_name, description=description, @@ -368,7 +370,7 @@ def create_tuned_model( @typing.overload def update_tuned_model( - tuned_model: glm.TunedModel, + tuned_model: protos.TunedModel, updates: None = None, *, client: glm.ModelServiceClient | None = None, @@ -389,7 +391,7 @@ def update_tuned_model( def update_tuned_model( - tuned_model: str | glm.TunedModel, + tuned_model: str | protos.TunedModel, updates: dict[str, Any] | None = None, *, client: glm.ModelServiceClient | None = None, @@ -418,10 +420,11 @@ def update_tuned_model( field_mask.paths.append(path) for path, value in updates.items(): _apply_update(tuned_model, path, value) - elif isinstance(tuned_model, glm.TunedModel): + elif isinstance(tuned_model, protos.TunedModel): if updates is not None: raise ValueError( - "Invalid argument: When calling `update_tuned_model(tuned_model:glm.TunedModel, updates=None)`, the `updates` argument must not be set." + "Invalid argument: When calling `update_tuned_model(tuned_model:protos.TunedModel, updates=None)`, " + "the `updates` argument must not be set." ) name = tuned_model.name @@ -429,11 +432,12 @@ def update_tuned_model( field_mask = protobuf_helpers.field_mask(was._pb, tuned_model._pb) else: raise TypeError( - f"Invalid argument type: In the function `update_tuned_model(tuned_model:dict|glm.TunedModel)`, the `tuned_model` argument must be of type `dict` or `glm.TunedModel`. Received type: {type(tuned_model).__name__}." + "Invalid argument type: In the function `update_tuned_model(tuned_model:dict|protos.TunedModel)`, the " + f"`tuned_model` argument must be of type `dict` or `protos.TunedModel`. Received type: {type(tuned_model).__name__}." 
) result = client.update_tuned_model( - glm.UpdateTunedModelRequest(tuned_model=tuned_model, update_mask=field_mask), + protos.UpdateTunedModelRequest(tuned_model=tuned_model, update_mask=field_mask), **request_options, ) return model_types.decode_tuned_model(result) diff --git a/google/generativeai/operations.py b/google/generativeai/operations.py index 01c0a6b14..52fd8a1b8 100644 --- a/google/generativeai/operations.py +++ b/google/generativeai/operations.py @@ -17,7 +17,7 @@ import functools from typing import Iterator -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import client as client_lib from google.generativeai.types import model_types @@ -75,8 +75,8 @@ def from_proto(cls, proto, client): cls=CreateTunedModelOperation, operation=proto, operations_client=client, - result_type=glm.TunedModel, - metadata_type=glm.CreateTunedModelMetadata, + result_type=protos.TunedModel, + metadata_type=protos.CreateTunedModelMetadata, ) @classmethod @@ -111,14 +111,14 @@ def update(self): """Refresh the current statuses in metadata/result/error""" self._refresh_and_update() - def wait_bar(self, **kwargs) -> Iterator[glm.CreateTunedModelMetadata]: + def wait_bar(self, **kwargs) -> Iterator[protos.CreateTunedModelMetadata]: """A tqdm wait bar, yields `Operation` statuses until complete. Args: **kwargs: passed through to `tqdm.auto.tqdm(..., **kwargs)` Yields: - Operation statuses as `glm.CreateTunedModelMetadata` objects. + Operation statuses as `protos.CreateTunedModelMetadata` objects. """ bar = tqdm.tqdm(total=self.metadata.total_steps, initial=0, **kwargs) @@ -131,7 +131,7 @@ def wait_bar(self, **kwargs) -> Iterator[glm.CreateTunedModelMetadata]: bar.update(self.metadata.completed_steps - bar.n) return self.result() - def set_result(self, result: glm.TunedModel): + def set_result(self, result: protos.TunedModel): result = model_types.decode_tuned_model(result) super().set_result(result) diff --git a/google/generativeai/protos.py b/google/generativeai/protos.py new file mode 100644 index 000000000..010396c75 --- /dev/null +++ b/google/generativeai/protos.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module provides low level access to the ProtoBuffer "Message" classes used by the API. + +**For typical usage of this SDK you do not need to use any of these classes.** + +ProtoBufers are Google API's serilization format. They are strongly typed and efficient. + +The `genai` SDK tries to be permissive about what objects it will accept from a user, but in the end +the SDK always converts input to an appropriate Proto Message object to send as the request. Each API request +has a `*Request` and `*Response` Message defined here. + +If you have any uncertainty about what the API may accept or return, these classes provide the +complete/unambiguous answer. 
They come from the `google-ai-generativelanguage` package which is +generated from a snapshot of the API definition. + +>>> from google.generativeai import protos +>>> import inspect +>>> print(inspect.getsource(protos.Part)) + +Proto classes can have "oneof" fields. Use `in` to check which `oneof` field is set. + +>>> p = protos.Part(text='hello') +>>> 'text' in p +True +>>> p.inline_data = {'mime_type':'image/png', 'data': b'PNG'} +>>> type(p.inline_data) is protos.Blob +True +>>> 'inline_data' in p +True +>>> 'text' in p +False + +Instances of all Message classes can be converted into JSON compatible dictionaries with the following construct +(Bytes are base64 encoded): + +>>> p_dict = type(p).to_dict(p) +>>> p_dict +{'inline_data': {'mime_type': 'image/png', 'data': 'UE5H'}} + +A compatible dict can be converted to an instance of a Message class by passing it as the first argument to the +constructor: + +>>> p = protos.Part(p_dict) +inline_data { + mime_type: "image/png" + data: "PNG" +} + +Note when converting that `to_dict` accepts additional arguments: + +- `use_integers_for_enums:bool = True`, Set it to `False` to replace enum int values with their string + names in the output +- ` including_default_value_fields:bool = True`, Set it to `False` to reduce the verbosity of the output. + +Additional arguments are described in the docstring: + +>>> help(proto.Part.to_dict) +""" + +from google.ai.generativelanguage_v1beta.types import * +from google.ai.generativelanguage_v1beta.types import __all__ diff --git a/google/generativeai/responder.py b/google/generativeai/responder.py index 814bf3581..bb85167ad 100644 --- a/google/generativeai/responder.py +++ b/google/generativeai/responder.py @@ -22,9 +22,9 @@ import pydantic -from google.ai import generativelanguage as glm +from google.generativeai import protos -Type = glm.Type +Type = protos.Type TypeOptions = Union[int, str, Type] @@ -186,8 +186,8 @@ def _rename_schema_fields(schema: dict[str, Any]): class FunctionDeclaration: def __init__(self, *, name: str, description: str, parameters: dict[str, Any] | None = None): - """A class wrapping a `glm.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" - self._proto = glm.FunctionDeclaration( + """A class wrapping a `protos.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" + self._proto = protos.FunctionDeclaration( name=name, description=description, parameters=_rename_schema_fields(parameters) ) @@ -200,7 +200,7 @@ def description(self) -> str: return self._proto.description @property - def parameters(self) -> glm.Schema: + def parameters(self) -> protos.Schema: return self._proto.parameters @classmethod @@ -209,7 +209,7 @@ def from_proto(cls, proto) -> FunctionDeclaration: self._proto = proto return self - def to_proto(self) -> glm.FunctionDeclaration: + def to_proto(self) -> protos.FunctionDeclaration: return self._proto @staticmethod @@ -255,16 +255,16 @@ def __init__( super().__init__(name=name, description=description, parameters=parameters) self.function = function - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse: result = self.function(**fc.args) if not isinstance(result, dict): result = {"result": result} - return glm.FunctionResponse(name=fc.name, response=result) + return protos.FunctionResponse(name=fc.name, response=result) FunctionDeclarationType = Union[ FunctionDeclaration, - glm.FunctionDeclaration, + 
protos.FunctionDeclaration, dict[str, Any], Callable[..., Any], ] @@ -272,8 +272,8 @@ def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: def _make_function_declaration( fun: FunctionDeclarationType, -) -> FunctionDeclaration | glm.FunctionDeclaration: - if isinstance(fun, (FunctionDeclaration, glm.FunctionDeclaration)): +) -> FunctionDeclaration | protos.FunctionDeclaration: + if isinstance(fun, (FunctionDeclaration, protos.FunctionDeclaration)): return fun elif isinstance(fun, dict): if "function" in fun: @@ -289,15 +289,15 @@ def _make_function_declaration( ) -def _encode_fd(fd: FunctionDeclaration | glm.FunctionDeclaration) -> glm.FunctionDeclaration: - if isinstance(fd, glm.FunctionDeclaration): +def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.FunctionDeclaration: + if isinstance(fd, protos.FunctionDeclaration): return fd return fd.to_proto() class Tool: - """A wrapper for `glm.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" + """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): # The main path doesn't use this but is seems useful. @@ -309,23 +309,23 @@ def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): raise ValueError("") self._index[fd.name] = fd - self._proto = glm.Tool( + self._proto = protos.Tool( function_declarations=[_encode_fd(fd) for fd in self._function_declarations] ) @property - def function_declarations(self) -> list[FunctionDeclaration | glm.FunctionDeclaration]: + def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]: return self._function_declarations def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse | None: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse | None: declaration = self[fc] if not callable(declaration): return None @@ -341,21 +341,21 @@ class ToolDict(TypedDict): ToolType = Union[ - Tool, glm.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType + Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType ] def _make_tool(tool: ToolType) -> Tool: if isinstance(tool, Tool): return tool - elif isinstance(tool, glm.Tool): + elif isinstance(tool, protos.Tool): return Tool(function_declarations=tool.function_declarations) elif isinstance(tool, dict): if "function_declarations" in tool: return Tool(**tool) else: fd = tool - return Tool(function_declarations=[glm.FunctionDeclaration(**fd)]) + return Tool(function_declarations=[protos.FunctionDeclaration(**fd)]) elif isinstance(tool, Iterable): return Tool(function_declarations=tool) else: @@ -385,20 +385,20 @@ def __init__(self, tools: Iterable[ToolType]): self._index[declaration.name] = declaration def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.Part | None: + def __call__(self, fc: protos.FunctionCall) -> protos.Part | 
None: declaration = self[fc] if not callable(declaration): return None response = declaration(fc) - return glm.Part(function_response=response) + return protos.Part(function_response=response) def to_proto(self): return [tool.to_proto() for tool in self._tools] @@ -431,7 +431,7 @@ def to_function_library(lib: FunctionLibraryType | None) -> FunctionLibrary | No return FunctionLibrary(tools=lib) -FunctionCallingMode = glm.FunctionCallingConfig.Mode +FunctionCallingMode = protos.FunctionCallingConfig.Mode # fmt: off _FUNCTION_CALLING_MODE = { @@ -467,12 +467,12 @@ class FunctionCallingConfigDict(TypedDict): FunctionCallingConfigType = Union[ - FunctionCallingModeType, FunctionCallingConfigDict, glm.FunctionCallingConfig + FunctionCallingModeType, FunctionCallingConfigDict, protos.FunctionCallingConfig ] -def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCallingConfig: - if isinstance(obj, glm.FunctionCallingConfig): +def to_function_calling_config(obj: FunctionCallingConfigType) -> protos.FunctionCallingConfig: + if isinstance(obj, protos.FunctionCallingConfig): return obj elif isinstance(obj, (FunctionCallingMode, str, int)): obj = {"mode": to_function_calling_mode(obj)} @@ -482,30 +482,31 @@ def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCa obj["mode"] = to_function_calling_mode(mode) else: raise TypeError( - f"Invalid argument type: Could not convert input to `glm.FunctionCallingConfig`. Received type: {type(obj).__name__}.", + "Invalid argument type: Could not convert input to `protos.FunctionCallingConfig`." + f" Received type: {type(obj).__name__}.", obj, ) - return glm.FunctionCallingConfig(obj) + return protos.FunctionCallingConfig(obj) class ToolConfigDict: function_calling_config: FunctionCallingConfigType -ToolConfigType = Union[ToolConfigDict, glm.ToolConfig] +ToolConfigType = Union[ToolConfigDict, protos.ToolConfig] -def to_tool_config(obj: ToolConfigType) -> glm.ToolConfig: - if isinstance(obj, glm.ToolConfig): +def to_tool_config(obj: ToolConfigType) -> protos.ToolConfig: + if isinstance(obj, protos.ToolConfig): return obj elif isinstance(obj, dict): fcc = obj.pop("function_calling_config") fcc = to_function_calling_config(fcc) obj["function_calling_config"] = fcc - return glm.ToolConfig(**obj) + return protos.ToolConfig(**obj) else: raise TypeError( - f"Invalid argument type: Could not convert input to `glm.ToolConfig`. Received type: {type(obj).__name__}.", - obj, + "Invalid argument type: Could not convert input to `protos.ToolConfig`. " + f"Received type: {type(obj).__name__}.", ) diff --git a/google/generativeai/retriever.py b/google/generativeai/retriever.py index e295bc5b7..53c90140a 100644 --- a/google/generativeai/retriever.py +++ b/google/generativeai/retriever.py @@ -14,12 +14,11 @@ # limitations under the License. 
from __future__ import annotations -import re -import string -import dataclasses -from typing import Any, AsyncIterable, Iterable, Optional + +from typing import AsyncIterable, Iterable, Optional import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.client import get_default_retriever_client from google.generativeai.client import get_default_retriever_async_client @@ -57,13 +56,13 @@ def create_corpus( client = get_default_retriever_client() if name is None: - corpus = glm.Corpus(display_name=display_name) + corpus = protos.Corpus(display_name=display_name) elif retriever_types.valid_name(name): - corpus = glm.Corpus(name=f"corpora/{name}", display_name=display_name) + corpus = protos.Corpus(name=f"corpora/{name}", display_name=display_name) else: raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateCorpusRequest(corpus=corpus) + request = protos.CreateCorpusRequest(corpus=corpus) response = client.create_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -86,13 +85,13 @@ async def create_corpus_async( client = get_default_retriever_async_client() if name is None: - corpus = glm.Corpus(display_name=display_name) + corpus = protos.Corpus(display_name=display_name) elif retriever_types.valid_name(name): - corpus = glm.Corpus(name=f"corpora/{name}", display_name=display_name) + corpus = protos.Corpus(name=f"corpora/{name}", display_name=display_name) else: raise ValueError(retriever_types.NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateCorpusRequest(corpus=corpus) + request = protos.CreateCorpusRequest(corpus=corpus) response = await client.create_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -124,7 +123,7 @@ def get_corpus( if "/" not in name: name = "corpora/" + name - request = glm.GetCorpusRequest(name=name) + request = protos.GetCorpusRequest(name=name) response = client.get_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -149,7 +148,7 @@ async def get_corpus_async( if "/" not in name: name = "corpora/" + name - request = glm.GetCorpusRequest(name=name) + request = protos.GetCorpusRequest(name=name) response = await client.get_corpus(request, **request_options) response = type(response).to_dict(response) idecode_time(response, "create_time") @@ -181,7 +180,7 @@ def delete_corpus( if "/" not in name: name = "corpora/" + name - request = glm.DeleteCorpusRequest(name=name, force=force) + request = protos.DeleteCorpusRequest(name=name, force=force) client.delete_corpus(request, **request_options) @@ -201,7 +200,7 @@ async def delete_corpus_async( if "/" not in name: name = "corpora/" + name - request = glm.DeleteCorpusRequest(name=name, force=force) + request = protos.DeleteCorpusRequest(name=name, force=force) await client.delete_corpus(request, **request_options) @@ -227,7 +226,7 @@ def list_corpora( if client is None: client = get_default_retriever_client() - request = glm.ListCorporaRequest(page_size=page_size) + request = protos.ListCorporaRequest(page_size=page_size) for corpus in client.list_corpora(request, **request_options): corpus = type(corpus).to_dict(corpus) idecode_time(corpus, "create_time") @@ -248,7 +247,7 @@ async def list_corpora_async( if client is None: client = get_default_retriever_async_client() - request = 
glm.ListCorporaRequest(page_size=page_size) + request = protos.ListCorporaRequest(page_size=page_size) async for corpus in await client.list_corpora(request, **request_options): corpus = type(corpus).to_dict(corpus) idecode_time(corpus, "create_time") diff --git a/google/generativeai/text.py b/google/generativeai/text.py index b8b814754..2a6267661 100644 --- a/google/generativeai/text.py +++ b/google/generativeai/text.py @@ -21,6 +21,8 @@ import google.ai.generativelanguage as glm +from google.generativeai import protos + from google.generativeai.client import get_default_text_client from google.generativeai import string_utils from google.generativeai.types import helper_types @@ -52,23 +54,23 @@ def _batched(iterable: Iterable[T], n: int) -> Iterable[list[T]]: yield batch -def _make_text_prompt(prompt: str | dict[str, str]) -> glm.TextPrompt: +def _make_text_prompt(prompt: str | dict[str, str]) -> protos.TextPrompt: """ - Creates a `glm.TextPrompt` object based on the provided prompt input. + Creates a `protos.TextPrompt` object based on the provided prompt input. Args: prompt: The prompt input, either a string or a dictionary. Returns: - glm.TextPrompt: A TextPrompt object containing the prompt text. + protos.TextPrompt: A TextPrompt object containing the prompt text. Raises: TypeError: If the provided prompt is neither a string nor a dictionary. """ if isinstance(prompt, str): - return glm.TextPrompt(text=prompt) + return protos.TextPrompt(text=prompt) elif isinstance(prompt, dict): - return glm.TextPrompt(prompt) + return protos.TextPrompt(prompt) else: raise TypeError( "Invalid argument type: Expected a string or dictionary for the text prompt." @@ -86,11 +88,11 @@ def _make_generate_text_request( top_k: int | None = None, safety_settings: palm_safety_types.SafetySettingOptions | None = None, stop_sequences: str | Iterable[str] | None = None, -) -> glm.GenerateTextRequest: +) -> protos.GenerateTextRequest: """ - Creates a `glm.GenerateTextRequest` object based on the provided parameters. + Creates a `protos.GenerateTextRequest` object based on the provided parameters. - This function generates a `glm.GenerateTextRequest` object with the specified + This function generates a `protos.GenerateTextRequest` object with the specified parameters. It prepares the input parameters and creates a request that can be used for generating text using the chosen model. @@ -107,7 +109,7 @@ def _make_generate_text_request( or iterable of strings. Defaults to None. Returns: - `glm.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters. + `protos.GenerateTextRequest`: A `GenerateTextRequest` object configured with the specified parameters. """ model = model_types.make_model_name(model) prompt = _make_text_prompt(prompt=prompt) @@ -117,7 +119,7 @@ def _make_generate_text_request( if stop_sequences: stop_sequences = list(stop_sequences) - return glm.GenerateTextRequest( + return protos.GenerateTextRequest( model=model, prompt=prompt, temperature=temperature, @@ -216,12 +218,12 @@ def __init__(self, **kwargs): def _generate_response( - request: glm.GenerateTextRequest, + request: protos.GenerateTextRequest, client: glm.TextServiceClient = None, request_options: helper_types.RequestOptionsType | None = None, ) -> Completion: """ - Generates a response using the provided `glm.GenerateTextRequest` and client. + Generates a response using the provided `protos.GenerateTextRequest` and client. Args: request: The text generation request. 
@@ -267,7 +269,7 @@ def count_text_tokens( client = get_default_text_client() result = client.count_text_tokens( - glm.CountTextTokensRequest(model=base_model, prompt={"text": prompt}), + protos.CountTextTokensRequest(model=base_model, prompt={"text": prompt}), **request_options, ) @@ -322,7 +324,7 @@ def generate_embeddings( client = get_default_text_client() if isinstance(text, str): - embedding_request = glm.EmbedTextRequest(model=model, text=text) + embedding_request = protos.EmbedTextRequest(model=model, text=text) embedding_response = client.embed_text( embedding_request, **request_options, @@ -333,7 +335,7 @@ def generate_embeddings( result = {"embedding": []} for batch in _batched(text, EMBEDDING_MAX_BATCH_SIZE): # TODO(markdaoust): This could use an option for returning an iterator or wait-bar. - embedding_request = glm.BatchEmbedTextRequest(model=model, texts=batch) + embedding_request = protos.BatchEmbedTextRequest(model=model, texts=batch) embedding_response = client.batch_embed_text( embedding_request, **request_options, diff --git a/google/generativeai/types/answer_types.py b/google/generativeai/types/answer_types.py index 18bd11d62..143a578a4 100644 --- a/google/generativeai/types/answer_types.py +++ b/google/generativeai/types/answer_types.py @@ -16,11 +16,11 @@ from typing import Union -import google.ai.generativelanguage as glm +from google.generativeai import protos __all__ = ["Answer"] -FinishReason = glm.Candidate.FinishReason +FinishReason = protos.Candidate.FinishReason FinishReasonOptions = Union[int, str, FinishReason] diff --git a/google/generativeai/types/citation_types.py b/google/generativeai/types/citation_types.py index ae857c35b..9f169703f 100644 --- a/google/generativeai/types/citation_types.py +++ b/google/generativeai/types/citation_types.py @@ -17,7 +17,7 @@ from typing_extensions import TypedDict -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils @@ -33,10 +33,10 @@ class CitationSourceDict(TypedDict): uri: str | None license: str | None - __doc__ = string_utils.strip_oneof(glm.CitationSource.__doc__) + __doc__ = string_utils.strip_oneof(protos.CitationSource.__doc__) class CitationMetadataDict(TypedDict): citation_sources: List[CitationSourceDict | None] - __doc__ = string_utils.strip_oneof(glm.CitationMetadata.__doc__) + __doc__ = string_utils.strip_oneof(protos.CitationMetadata.__doc__) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index 169683608..b8966b005 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -26,7 +26,7 @@ import pydantic from google.generativeai.types import file_types -from google.ai import generativelanguage as glm +from google.generativeai import protos if typing.TYPE_CHECKING: import PIL.Image @@ -80,10 +80,10 @@ def pil_to_blob(img): mime_type = "image/jpeg" bytesio.seek(0) data = bytesio.read() - return glm.Blob(mime_type=mime_type, data=data) + return protos.Blob(mime_type=mime_type, data=data) -def image_to_blob(image) -> glm.Blob: +def image_to_blob(image) -> protos.Blob: if PIL is not None: if isinstance(image, PIL.Image.Image): return pil_to_blob(image) @@ -100,7 +100,7 @@ def image_to_blob(image) -> glm.Blob: if mime_type is None: mime_type = "image/unknown" - return glm.Blob(mime_type=mime_type, data=image.data) + return protos.Blob(mime_type=mime_type, data=image.data) raise TypeError( "Image conversion failed. 
The input was expected to be of type `Image` " @@ -115,23 +115,23 @@ class BlobDict(TypedDict): data: bytes -def _convert_dict(d: Mapping) -> glm.Content | glm.Part | glm.Blob: +def _convert_dict(d: Mapping) -> protos.Content | protos.Part | protos.Blob: if is_content_dict(d): content = dict(d) if isinstance(parts := content["parts"], str): content["parts"] = [parts] content["parts"] = [to_part(part) for part in content["parts"]] - return glm.Content(content) + return protos.Content(content) elif is_part_dict(d): part = dict(d) if "inline_data" in part: part["inline_data"] = to_blob(part["inline_data"]) if "file_data" in part: part["file_data"] = file_types.to_file_data(part["file_data"]) - return glm.Part(part) + return protos.Part(part) elif is_blob_dict(d): blob = d - return glm.Blob(blob) + return protos.Blob(blob) else: raise KeyError( "Unable to determine the intended type of the `dict`. " @@ -148,17 +148,17 @@ def is_blob_dict(d): if typing.TYPE_CHECKING: BlobType = Union[ - glm.Blob, BlobDict, PIL.Image.Image, IPython.display.Image + protos.Blob, BlobDict, PIL.Image.Image, IPython.display.Image ] # Any for the images else: - BlobType = Union[glm.Blob, BlobDict, Any] + BlobType = Union[protos.Blob, BlobDict, Any] -def to_blob(blob: BlobType) -> glm.Blob: +def to_blob(blob: BlobType) -> protos.Blob: if isinstance(blob, Mapping): blob = _convert_dict(blob) - if isinstance(blob, glm.Blob): + if isinstance(blob, protos.Blob): return blob elif isinstance(blob, IMAGE_TYPES): return image_to_blob(blob) @@ -182,12 +182,12 @@ class PartDict(TypedDict): # When you need a `Part` accept a part object, part-dict, blob or string PartType = Union[ - glm.Part, + protos.Part, PartDict, BlobType, str, - glm.FunctionCall, - glm.FunctionResponse, + protos.FunctionCall, + protos.FunctionResponse, file_types.FileDataType, ] @@ -206,22 +206,22 @@ def to_part(part: PartType): if isinstance(part, Mapping): part = _convert_dict(part) - if isinstance(part, glm.Part): + if isinstance(part, protos.Part): return part elif isinstance(part, str): - return glm.Part(text=part) - elif isinstance(part, glm.FileData): - return glm.Part(file_data=part) - elif isinstance(part, (glm.File, file_types.File)): - return glm.Part(file_data=file_types.to_file_data(part)) - elif isinstance(part, glm.FunctionCall): - return glm.Part(function_call=part) - elif isinstance(part, glm.FunctionResponse): - return glm.Part(function_response=part) + return protos.Part(text=part) + elif isinstance(part, protos.FileData): + return protos.Part(file_data=part) + elif isinstance(part, (protos.File, file_types.File)): + return protos.Part(file_data=file_types.to_file_data(part)) + elif isinstance(part, protos.FunctionCall): + return protos.Part(function_call=part) + elif isinstance(part, protos.FunctionResponse): + return protos.Part(function_response=part) else: # Maybe it can be turned into a blob? - return glm.Part(inline_data=to_blob(part)) + return protos.Part(inline_data=to_blob(part)) class ContentDict(TypedDict): @@ -235,10 +235,10 @@ def is_content_dict(d): # When you need a message accept a `Content` object or dict, a list of parts, # or a single part -ContentType = Union[glm.Content, ContentDict, Iterable[PartType], PartType] +ContentType = Union[protos.Content, ContentDict, Iterable[PartType], PartType] # For generate_content, we're not guessing roles for [[parts],[parts],[parts]] yet. 
-StrictContentType = Union[glm.Content, ContentDict] +StrictContentType = Union[protos.Content, ContentDict] def to_content(content: ContentType): @@ -250,24 +250,24 @@ def to_content(content: ContentType): if isinstance(content, Mapping): content = _convert_dict(content) - if isinstance(content, glm.Content): + if isinstance(content, protos.Content): return content elif isinstance(content, Iterable) and not isinstance(content, str): - return glm.Content(parts=[to_part(part) for part in content]) + return protos.Content(parts=[to_part(part) for part in content]) else: # Maybe this is a Part? - return glm.Content(parts=[to_part(content)]) + return protos.Content(parts=[to_part(content)]) def strict_to_content(content: StrictContentType): if isinstance(content, Mapping): content = _convert_dict(content) - if isinstance(content, glm.Content): + if isinstance(content, protos.Content): return content else: raise TypeError( - "Invalid input type. Expected a `glm.Content` or a `dict` with a 'parts' key.\n" + "Invalid input type. Expected a `protos.Content` or a `dict` with a 'parts' key.\n" f"However, received an object of type: {type(content)}.\n" f"Object Value: {content}" ) @@ -276,7 +276,7 @@ def strict_to_content(content: StrictContentType): ContentsType = Union[ContentType, Iterable[StrictContentType], None] -def to_contents(contents: ContentsType) -> list[glm.Content]: +def to_contents(contents: ContentsType) -> list[protos.Content]: if contents is None: return [] @@ -509,8 +509,8 @@ def _rename_schema_fields(schema): class FunctionDeclaration: def __init__(self, *, name: str, description: str, parameters: dict[str, Any] | None = None): - """A class wrapping a `glm.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" - self._proto = glm.FunctionDeclaration( + """A class wrapping a `protos.FunctionDeclaration`, describes a function for `genai.GenerativeModel`'s `tools`.""" + self._proto = protos.FunctionDeclaration( name=name, description=description, parameters=_rename_schema_fields(parameters) ) @@ -523,7 +523,7 @@ def description(self) -> str: return self._proto.description @property - def parameters(self) -> glm.Schema: + def parameters(self) -> protos.Schema: return self._proto.parameters @classmethod @@ -532,7 +532,7 @@ def from_proto(cls, proto) -> FunctionDeclaration: self._proto = proto return self - def to_proto(self) -> glm.FunctionDeclaration: + def to_proto(self) -> protos.FunctionDeclaration: return self._proto @staticmethod @@ -578,16 +578,16 @@ def __init__( super().__init__(name=name, description=description, parameters=parameters) self.function = function - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse: result = self.function(**fc.args) if not isinstance(result, dict): result = {"result": result} - return glm.FunctionResponse(name=fc.name, response=result) + return protos.FunctionResponse(name=fc.name, response=result) FunctionDeclarationType = Union[ FunctionDeclaration, - glm.FunctionDeclaration, + protos.FunctionDeclaration, dict[str, Any], Callable[..., Any], ] @@ -595,8 +595,8 @@ def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse: def _make_function_declaration( fun: FunctionDeclarationType, -) -> FunctionDeclaration | glm.FunctionDeclaration: - if isinstance(fun, (FunctionDeclaration, glm.FunctionDeclaration)): +) -> FunctionDeclaration | protos.FunctionDeclaration: + if isinstance(fun, (FunctionDeclaration, 
protos.FunctionDeclaration)): return fun elif isinstance(fun, dict): if "function" in fun: @@ -613,15 +613,15 @@ def _make_function_declaration( ) -def _encode_fd(fd: FunctionDeclaration | glm.FunctionDeclaration) -> glm.FunctionDeclaration: - if isinstance(fd, glm.FunctionDeclaration): +def _encode_fd(fd: FunctionDeclaration | protos.FunctionDeclaration) -> protos.FunctionDeclaration: + if isinstance(fd, protos.FunctionDeclaration): return fd return fd.to_proto() class Tool: - """A wrapper for `glm.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" + """A wrapper for `protos.Tool`, Contains a collection of related `FunctionDeclaration` objects.""" def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): # The main path doesn't use this but is seems useful. @@ -633,23 +633,23 @@ def __init__(self, function_declarations: Iterable[FunctionDeclarationType]): raise ValueError("") self._index[fd.name] = fd - self._proto = glm.Tool( + self._proto = protos.Tool( function_declarations=[_encode_fd(fd) for fd in self._function_declarations] ) @property - def function_declarations(self) -> list[FunctionDeclaration | glm.FunctionDeclaration]: + def function_declarations(self) -> list[FunctionDeclaration | protos.FunctionDeclaration]: return self._function_declarations def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.FunctionResponse | None: + def __call__(self, fc: protos.FunctionCall) -> protos.FunctionResponse | None: declaration = self[fc] if not callable(declaration): return None @@ -665,21 +665,21 @@ class ToolDict(TypedDict): ToolType = Union[ - Tool, glm.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType + Tool, protos.Tool, ToolDict, Iterable[FunctionDeclarationType], FunctionDeclarationType ] def _make_tool(tool: ToolType) -> Tool: if isinstance(tool, Tool): return tool - elif isinstance(tool, glm.Tool): + elif isinstance(tool, protos.Tool): return Tool(function_declarations=tool.function_declarations) elif isinstance(tool, dict): if "function_declarations" in tool: return Tool(**tool) else: fd = tool - return Tool(function_declarations=[glm.FunctionDeclaration(**fd)]) + return Tool(function_declarations=[protos.FunctionDeclaration(**fd)]) elif isinstance(tool, Iterable): return Tool(function_declarations=tool) else: @@ -711,20 +711,20 @@ def __init__(self, tools: Iterable[ToolType]): self._index[declaration.name] = declaration def __getitem__( - self, name: str | glm.FunctionCall - ) -> FunctionDeclaration | glm.FunctionDeclaration: + self, name: str | protos.FunctionCall + ) -> FunctionDeclaration | protos.FunctionDeclaration: if not isinstance(name, str): name = name.name return self._index[name] - def __call__(self, fc: glm.FunctionCall) -> glm.Part | None: + def __call__(self, fc: protos.FunctionCall) -> protos.Part | None: declaration = self[fc] if not callable(declaration): return None response = declaration(fc) - return glm.Part(function_response=response) + return protos.Part(function_response=response) def to_proto(self): return [tool.to_proto() for tool in self._tools] @@ -757,7 +757,7 @@ def to_function_library(lib: FunctionLibraryType | None) -> FunctionLibrary | No return FunctionLibrary(tools=lib) -FunctionCallingMode = 
glm.FunctionCallingConfig.Mode +FunctionCallingMode = protos.FunctionCallingConfig.Mode # fmt: off _FUNCTION_CALLING_MODE = { @@ -793,12 +793,12 @@ class FunctionCallingConfigDict(TypedDict): FunctionCallingConfigType = Union[ - FunctionCallingModeType, FunctionCallingConfigDict, glm.FunctionCallingConfig + FunctionCallingModeType, FunctionCallingConfigDict, protos.FunctionCallingConfig ] -def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCallingConfig: - if isinstance(obj, glm.FunctionCallingConfig): +def to_function_calling_config(obj: FunctionCallingConfigType) -> protos.FunctionCallingConfig: + if isinstance(obj, protos.FunctionCallingConfig): return obj elif isinstance(obj, (FunctionCallingMode, str, int)): obj = {"mode": to_function_calling_mode(obj)} @@ -808,32 +808,32 @@ def to_function_calling_config(obj: FunctionCallingConfigType) -> glm.FunctionCa obj["mode"] = to_function_calling_mode(mode) else: raise TypeError( - "Invalid input type. Failed to convert input to `glm.FunctionCallingConfig`.\n" + "Invalid input type. Failed to convert input to `protos.FunctionCallingConfig`.\n" f"Received an object of type: {type(obj)}.\n" f"Object Value: {obj}" ) - return glm.FunctionCallingConfig(obj) + return protos.FunctionCallingConfig(obj) class ToolConfigDict: function_calling_config: FunctionCallingConfigType -ToolConfigType = Union[ToolConfigDict, glm.ToolConfig] +ToolConfigType = Union[ToolConfigDict, protos.ToolConfig] -def to_tool_config(obj: ToolConfigType) -> glm.ToolConfig: - if isinstance(obj, glm.ToolConfig): +def to_tool_config(obj: ToolConfigType) -> protos.ToolConfig: + if isinstance(obj, protos.ToolConfig): return obj elif isinstance(obj, dict): fcc = obj.pop("function_calling_config") fcc = to_function_calling_config(fcc) obj["function_calling_config"] = fcc - return glm.ToolConfig(**obj) + return protos.ToolConfig(**obj) else: raise TypeError( - "Invalid input type. Failed to convert input to `glm.ToolConfig`.\n" + "Invalid input type. 
Failed to convert input to `protos.ToolConfig`.\n" f"Received an object of type: {type(obj)}.\n" f"Object Value: {obj}" ) diff --git a/google/generativeai/types/discuss_types.py b/google/generativeai/types/discuss_types.py index fa777d1d1..a538da65c 100644 --- a/google/generativeai/types/discuss_types.py +++ b/google/generativeai/types/discuss_types.py @@ -19,7 +19,7 @@ from typing import Any, Dict, Union, Iterable, Optional, Tuple, List from typing_extensions import TypedDict -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils from google.generativeai.types import palm_safety_types @@ -46,15 +46,15 @@ class TokenCount(TypedDict): class MessageDict(TypedDict): - """A dict representation of a `glm.Message`.""" + """A dict representation of a `protos.Message`.""" author: str content: str citation_metadata: Optional[citation_types.CitationMetadataDict] -MessageOptions = Union[str, MessageDict, glm.Message] -MESSAGE_OPTIONS = (str, dict, glm.Message) +MessageOptions = Union[str, MessageDict, protos.Message] +MESSAGE_OPTIONS = (str, dict, protos.Message) MessagesOptions = Union[ MessageOptions, @@ -64,7 +64,7 @@ class MessageDict(TypedDict): class ExampleDict(TypedDict): - """A dict representation of a `glm.Example`.""" + """A dict representation of a `protos.Example`.""" input: MessageOptions output: MessageOptions @@ -74,14 +74,14 @@ class ExampleDict(TypedDict): Tuple[MessageOptions, MessageOptions], Iterable[MessageOptions], ExampleDict, - glm.Example, + protos.Example, ] -EXAMPLE_OPTIONS = (glm.Example, dict, Iterable) +EXAMPLE_OPTIONS = (protos.Example, dict, Iterable) ExamplesOptions = Union[ExampleOptions, Iterable[ExampleOptions]] class MessagePromptDict(TypedDict, total=False): - """A dict representation of a `glm.MessagePrompt`.""" + """A dict representation of a `protos.MessagePrompt`.""" context: str examples: ExamplesOptions @@ -90,16 +90,16 @@ class MessagePromptDict(TypedDict, total=False): MessagePromptOptions = Union[ str, - glm.Message, - Iterable[Union[str, glm.Message]], + protos.Message, + Iterable[Union[str, protos.Message]], MessagePromptDict, - glm.MessagePrompt, + protos.MessagePrompt, ] MESSAGE_PROMPT_KEYS = {"context", "examples", "messages"} class ResponseDict(TypedDict): - """A dict representation of a `glm.GenerateMessageResponse`.""" + """A dict representation of a `protos.GenerateMessageResponse`.""" messages: List[MessageDict] candidates: List[MessageDict] diff --git a/google/generativeai/types/file_types.py b/google/generativeai/types/file_types.py index 0fdf05322..ef251e296 100644 --- a/google/generativeai/types/file_types.py +++ b/google/generativeai/types/file_types.py @@ -21,16 +21,16 @@ from google.rpc.status_pb2 import Status from google.generativeai.client import get_default_file_client -import google.ai.generativelanguage as glm +from google.generativeai import protos class File: - def __init__(self, proto: glm.File | File | dict): + def __init__(self, proto: protos.File | File | dict): if isinstance(proto, File): proto = proto.to_proto() - self._proto = glm.File(proto) + self._proto = protos.File(proto) - def to_proto(self) -> glm.File: + def to_proto(self) -> protos.File: return self._proto @property @@ -70,11 +70,11 @@ def uri(self) -> str: return self._proto.uri @property - def state(self) -> glm.File.State: + def state(self) -> protos.File.State: return self._proto.state @property - def video_metadata(self) -> glm.VideoMetadata: + def video_metadata(self) -> 
protos.VideoMetadata: return self._proto.video_metadata @property @@ -91,26 +91,26 @@ class FileDataDict(TypedDict): file_uri: str -FileDataType = Union[FileDataDict, glm.FileData, glm.File, File] +FileDataType = Union[FileDataDict, protos.FileData, protos.File, File] def to_file_data(file_data: FileDataType): if isinstance(file_data, dict): if "file_uri" in file_data: - file_data = glm.FileData(file_data) + file_data = protos.FileData(file_data) else: - file_data = glm.File(file_data) + file_data = protos.File(file_data) if isinstance(file_data, File): file_data = file_data.to_proto() - if isinstance(file_data, glm.File): - file_data = glm.FileData( + if isinstance(file_data, protos.File): + file_data = protos.FileData( mime_type=file_data.mime_type, file_uri=file_data.uri, ) - if isinstance(file_data, glm.FileData): + if isinstance(file_data, protos.FileData): return file_data else: raise TypeError( diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py index 8d39f76c7..20686a156 100644 --- a/google/generativeai/types/generation_types.py +++ b/google/generativeai/types/generation_types.py @@ -30,7 +30,7 @@ import google.protobuf.json_format import google.api_core.exceptions -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils from google.generativeai.types import content_types from google.generativeai.responder import _rename_schema_fields @@ -85,7 +85,7 @@ class GenerationConfigDict(TypedDict, total=False): max_output_tokens: int temperature: float response_mime_type: str - response_schema: glm.Schema | Mapping[str, Any] # fmt: off + response_schema: protos.Schema | Mapping[str, Any] # fmt: off @dataclasses.dataclass @@ -165,19 +165,19 @@ class GenerationConfig: top_p: float | None = None top_k: int | None = None response_mime_type: str | None = None - response_schema: glm.Schema | Mapping[str, Any] | None = None + response_schema: protos.Schema | Mapping[str, Any] | None = None -GenerationConfigType = Union[glm.GenerationConfig, GenerationConfigDict, GenerationConfig] +GenerationConfigType = Union[protos.GenerationConfig, GenerationConfigDict, GenerationConfig] def _normalize_schema(generation_config): - # Convert response_schema to glm.Schema for request + # Convert response_schema to protos.Schema for request response_schema = generation_config.get("response_schema", None) if response_schema is None: return - if isinstance(response_schema, glm.Schema): + if isinstance(response_schema, protos.Schema): return if isinstance(response_schema, type): @@ -191,13 +191,13 @@ def _normalize_schema(generation_config): response_schema = content_types._schema_for_class(response_schema) response_schema = _rename_schema_fields(response_schema) - generation_config["response_schema"] = glm.Schema(response_schema) + generation_config["response_schema"] = protos.Schema(response_schema) def to_generation_config_dict(generation_config: GenerationConfigType): if generation_config is None: return {} - elif isinstance(generation_config, glm.GenerationConfig): + elif isinstance(generation_config, protos.GenerationConfig): schema = generation_config.response_schema generation_config = type(generation_config).to_dict( generation_config @@ -221,14 +221,14 @@ def to_generation_config_dict(generation_config: GenerationConfigType): def _join_citation_metadatas( - citation_metadatas: Iterable[glm.CitationMetadata], + citation_metadatas: Iterable[protos.CitationMetadata], ): 
citation_metadatas = list(citation_metadatas) return citation_metadatas[-1] def _join_safety_ratings_lists( - safety_ratings_lists: Iterable[list[glm.SafetyRating]], + safety_ratings_lists: Iterable[list[protos.SafetyRating]], ): ratings = {} blocked = collections.defaultdict(list) @@ -243,13 +243,13 @@ def _join_safety_ratings_lists( safety_list = [] for (category, probability), blocked in zip(ratings.items(), blocked.values()): safety_list.append( - glm.SafetyRating(category=category, probability=probability, blocked=blocked) + protos.SafetyRating(category=category, probability=probability, blocked=blocked) ) return safety_list -def _join_contents(contents: Iterable[glm.Content]): +def _join_contents(contents: Iterable[protos.Content]): contents = tuple(contents) roles = [c.role for c in contents if c.role] if roles: @@ -271,22 +271,22 @@ def _join_contents(contents: Iterable[glm.Content]): merged_parts.append(part) continue - merged_part = glm.Part(merged_parts[-1]) + merged_part = protos.Part(merged_parts[-1]) merged_part.text += part.text merged_parts[-1] = merged_part - return glm.Content( + return protos.Content( role=role, parts=merged_parts, ) -def _join_candidates(candidates: Iterable[glm.Candidate]): +def _join_candidates(candidates: Iterable[protos.Candidate]): candidates = tuple(candidates) index = candidates[0].index # These should all be the same. - return glm.Candidate( + return protos.Candidate( index=index, content=_join_contents([c.content for c in candidates]), finish_reason=candidates[-1].finish_reason, @@ -296,7 +296,7 @@ def _join_candidates(candidates: Iterable[glm.Candidate]): ) -def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]): +def _join_candidate_lists(candidate_lists: Iterable[list[protos.Candidate]]): # Assuming that is a candidate ends, it is no longer returned in the list of # candidates and that's why candidates have an index candidates = collections.defaultdict(list) @@ -312,15 +312,15 @@ def _join_candidate_lists(candidate_lists: Iterable[list[glm.Candidate]]): def _join_prompt_feedbacks( - prompt_feedbacks: Iterable[glm.GenerateContentResponse.PromptFeedback], + prompt_feedbacks: Iterable[protos.GenerateContentResponse.PromptFeedback], ): # Always return the first prompt feedback. 
return next(iter(prompt_feedbacks)) -def _join_chunks(chunks: Iterable[glm.GenerateContentResponse]): +def _join_chunks(chunks: Iterable[protos.GenerateContentResponse]): chunks = tuple(chunks) - return glm.GenerateContentResponse( + return protos.GenerateContentResponse( candidates=_join_candidate_lists(c.candidates for c in chunks), prompt_feedback=_join_prompt_feedbacks(c.prompt_feedback for c in chunks), usage_metadata=chunks[-1].usage_metadata, @@ -338,11 +338,11 @@ def __init__( done: bool, iterator: ( None - | Iterable[glm.GenerateContentResponse] - | AsyncIterable[glm.GenerateContentResponse] + | Iterable[protos.GenerateContentResponse] + | AsyncIterable[protos.GenerateContentResponse] ), - result: glm.GenerateContentResponse, - chunks: Iterable[glm.GenerateContentResponse] | None = None, + result: protos.GenerateContentResponse, + chunks: Iterable[protos.GenerateContentResponse] | None = None, ): self._done = done self._iterator = iterator @@ -440,7 +440,7 @@ def __str__(self) -> str: ) json_str = json.dumps(as_dict, indent=2) - _result = f"glm.GenerateContentResponse({json_str})" + _result = f"protos.GenerateContentResponse({json_str})" _result = _result.replace("\n", "\n ") if self._error: @@ -478,7 +478,7 @@ def rewrite_stream_error(): GENERATE_CONTENT_RESPONSE_DOC = """Instances of this class manage the response of the `generate_content` method. These are returned by `GenerativeModel.generate_content` and `ChatSession.send_message`. - This object is based on the low level `glm.GenerateContentResponse` class which just has `prompt_feedback` + This object is based on the low level `protos.GenerateContentResponse` class which just has `prompt_feedback` and `candidates` attributes. This class adds several quick accessors for common use cases. The same object type is returned for both `stream=True/False`. 
@@ -507,7 +507,7 @@ def rewrite_stream_error(): @string_utils.set_doc(GENERATE_CONTENT_RESPONSE_DOC) class GenerateContentResponse(BaseGenerateContentResponse): @classmethod - def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]): + def from_iterator(cls, iterator: Iterable[protos.GenerateContentResponse]): iterator = iter(iterator) with rewrite_stream_error(): response = next(iterator) @@ -519,7 +519,7 @@ def from_iterator(cls, iterator: Iterable[glm.GenerateContentResponse]): ) @classmethod - def from_response(cls, response: glm.GenerateContentResponse): + def from_response(cls, response: protos.GenerateContentResponse): return cls( done=True, iterator=None, @@ -574,7 +574,7 @@ def resolve(self): @string_utils.set_doc(ASYNC_GENERATE_CONTENT_RESPONSE_DOC) class AsyncGenerateContentResponse(BaseGenerateContentResponse): @classmethod - async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentResponse]): + async def from_aiterator(cls, iterator: AsyncIterable[protos.GenerateContentResponse]): iterator = aiter(iterator) # type: ignore with rewrite_stream_error(): response = await anext(iterator) # type: ignore @@ -586,7 +586,7 @@ async def from_aiterator(cls, iterator: AsyncIterable[glm.GenerateContentRespons ) @classmethod - def from_response(cls, response: glm.GenerateContentResponse): + def from_response(cls, response: protos.GenerateContentResponse): return cls( done=True, iterator=None, diff --git a/google/generativeai/types/model_types.py b/google/generativeai/types/model_types.py index 32b3bddae..81a545b30 100644 --- a/google/generativeai/types/model_types.py +++ b/google/generativeai/types/model_types.py @@ -28,7 +28,7 @@ import urllib.request from typing_extensions import TypedDict -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.types import permission_types from google.generativeai import string_utils @@ -44,7 +44,7 @@ "TunedModelState", ] -TunedModelState = glm.TunedModel.State +TunedModelState = protos.TunedModel.State TunedModelStateOptions = Union[None, str, int, TunedModelState] @@ -91,7 +91,7 @@ def to_tuned_model_state(x: TunedModelStateOptions) -> TunedModelState: @string_utils.prettyprint @dataclasses.dataclass class Model: - """A dataclass representation of a `glm.Model`. + """A dataclass representation of a `protos.Model`. Attributes: name: The resource name of the `Model`. 
Format: `models/{model}` with a `{model}` naming @@ -140,8 +140,8 @@ def idecode_time(parent: dict["str", Any], name: str): parent[name] = dt -def decode_tuned_model(tuned_model: glm.TunedModel | dict["str", Any]) -> TunedModel: - if isinstance(tuned_model, glm.TunedModel): +def decode_tuned_model(tuned_model: protos.TunedModel | dict["str", Any]) -> TunedModel: + if isinstance(tuned_model, protos.TunedModel): tuned_model = type(tuned_model).to_dict(tuned_model) # pytype: disable=attribute-error tuned_model["state"] = to_tuned_model_state(tuned_model.pop("state", None)) @@ -180,7 +180,7 @@ def decode_tuned_model(tuned_model: glm.TunedModel | dict["str", Any]) -> TunedM @string_utils.prettyprint @dataclasses.dataclass class TunedModel: - """A dataclass representation of a `glm.TunedModel`.""" + """A dataclass representation of a `protos.TunedModel`.""" name: str | None = None source_model: str | None = None @@ -214,13 +214,13 @@ class TuningExampleDict(TypedDict): output: str -TuningExampleOptions = Union[TuningExampleDict, glm.TuningExample, tuple[str, str], list[str]] +TuningExampleOptions = Union[TuningExampleDict, protos.TuningExample, tuple[str, str], list[str]] # TODO(markdaoust): gs:// URLS? File-type argument for files without extension? TuningDataOptions = Union[ pathlib.Path, str, - glm.Dataset, + protos.Dataset, Mapping[str, Iterable[str]], Iterable[TuningExampleOptions], ] @@ -228,8 +228,8 @@ class TuningExampleDict(TypedDict): def encode_tuning_data( data: TuningDataOptions, input_key="text_input", output_key="output" -) -> glm.Dataset: - if isinstance(data, glm.Dataset): +) -> protos.Dataset: + if isinstance(data, protos.Dataset): return data if isinstance(data, str): @@ -301,8 +301,8 @@ def _convert_dict(data, input_key, output_key): ) for i, o in zip(inputs, outputs): - new_data.append(glm.TuningExample({"text_input": str(i), "output": str(o)})) - return glm.Dataset(examples=glm.TuningExamples(examples=new_data)) + new_data.append(protos.TuningExample({"text_input": str(i), "output": str(o)})) + return protos.Dataset(examples=protos.TuningExamples(examples=new_data)) def _convert_iterable(data, input_key, output_key): @@ -310,17 +310,17 @@ def _convert_iterable(data, input_key, output_key): for example in data: example = encode_tuning_example(example, input_key, output_key) new_data.append(example) - return glm.Dataset(examples=glm.TuningExamples(examples=new_data)) + return protos.Dataset(examples=protos.TuningExamples(examples=new_data)) def encode_tuning_example(example: TuningExampleOptions, input_key, output_key): - if isinstance(example, glm.TuningExample): + if isinstance(example, protos.TuningExample): return example elif isinstance(example, (tuple, list)): a, b = example - example = glm.TuningExample(text_input=a, output=b) + example = protos.TuningExample(text_input=a, output=b) else: # dict - example = glm.TuningExample(text_input=example[input_key], output=example[output_key]) + example = protos.TuningExample(text_input=example[input_key], output=example[output_key]) return example @@ -341,14 +341,14 @@ class Hyperparameters: learning_rate: float = 0.0 -BaseModelNameOptions = Union[str, Model, glm.Model] -TunedModelNameOptions = Union[str, TunedModel, glm.TunedModel] -AnyModelNameOptions = Union[str, Model, glm.Model, TunedModel, glm.TunedModel] +BaseModelNameOptions = Union[str, Model, protos.Model] +TunedModelNameOptions = Union[str, TunedModel, protos.TunedModel] +AnyModelNameOptions = Union[str, Model, protos.Model, TunedModel, protos.TunedModel] 
ModelNameOptions = AnyModelNameOptions def make_model_name(name: AnyModelNameOptions): - if isinstance(name, (Model, glm.Model, TunedModel, glm.TunedModel)): + if isinstance(name, (Model, protos.Model, TunedModel, protos.TunedModel)): name = name.name # pytype: disable=attribute-error elif isinstance(name, str): name = name @@ -372,7 +372,7 @@ def make_model_name(name: AnyModelNameOptions): @string_utils.prettyprint @dataclasses.dataclass class TokenCount: - """A dataclass representation of a `glm.TokenCountResponse`. + """A dataclass representation of a `protos.TokenCountResponse`. Attributes: token_count: The number of tokens returned by the model's tokenizer for the `input_text`. diff --git a/google/generativeai/types/palm_safety_types.py b/google/generativeai/types/palm_safety_types.py index 9fb88cd67..0ab85e1b2 100644 --- a/google/generativeai/types/palm_safety_types.py +++ b/google/generativeai/types/palm_safety_types.py @@ -23,7 +23,7 @@ from typing_extensions import TypedDict -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils @@ -39,9 +39,9 @@ ] # These are basic python enums, it's okay to expose them -HarmProbability = glm.SafetyRating.HarmProbability -HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold -BlockedReason = glm.ContentFilter.BlockedReason +HarmProbability = protos.SafetyRating.HarmProbability +HarmBlockThreshold = protos.SafetySetting.HarmBlockThreshold +BlockedReason = protos.ContentFilter.BlockedReason class HarmCategory: @@ -49,70 +49,70 @@ class HarmCategory: Harm Categories supported by the palm-family models """ - HARM_CATEGORY_UNSPECIFIED = glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value - HARM_CATEGORY_DEROGATORY = glm.HarmCategory.HARM_CATEGORY_DEROGATORY.value - HARM_CATEGORY_TOXICITY = glm.HarmCategory.HARM_CATEGORY_TOXICITY.value - HARM_CATEGORY_VIOLENCE = glm.HarmCategory.HARM_CATEGORY_VIOLENCE.value - HARM_CATEGORY_SEXUAL = glm.HarmCategory.HARM_CATEGORY_SEXUAL.value - HARM_CATEGORY_MEDICAL = glm.HarmCategory.HARM_CATEGORY_MEDICAL.value - HARM_CATEGORY_DANGEROUS = glm.HarmCategory.HARM_CATEGORY_DANGEROUS.value + HARM_CATEGORY_UNSPECIFIED = protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value + HARM_CATEGORY_DEROGATORY = protos.HarmCategory.HARM_CATEGORY_DEROGATORY.value + HARM_CATEGORY_TOXICITY = protos.HarmCategory.HARM_CATEGORY_TOXICITY.value + HARM_CATEGORY_VIOLENCE = protos.HarmCategory.HARM_CATEGORY_VIOLENCE.value + HARM_CATEGORY_SEXUAL = protos.HarmCategory.HARM_CATEGORY_SEXUAL.value + HARM_CATEGORY_MEDICAL = protos.HarmCategory.HARM_CATEGORY_MEDICAL.value + HARM_CATEGORY_DANGEROUS = protos.HarmCategory.HARM_CATEGORY_DANGEROUS.value HarmCategoryOptions = Union[str, int, HarmCategory] # fmt: off -_HARM_CATEGORIES: Dict[HarmCategoryOptions, glm.HarmCategory] = { - glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - 0: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "harm_category_unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - - glm.HarmCategory.HARM_CATEGORY_DEROGATORY: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - HarmCategory.HARM_CATEGORY_DEROGATORY: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - 1: glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - "harm_category_derogatory": glm.HarmCategory.HARM_CATEGORY_DEROGATORY, - "derogatory": glm.HarmCategory.HARM_CATEGORY_DEROGATORY, 
- - glm.HarmCategory.HARM_CATEGORY_TOXICITY: glm.HarmCategory.HARM_CATEGORY_TOXICITY, - HarmCategory.HARM_CATEGORY_TOXICITY: glm.HarmCategory.HARM_CATEGORY_TOXICITY, - 2: glm.HarmCategory.HARM_CATEGORY_TOXICITY, - "harm_category_toxicity": glm.HarmCategory.HARM_CATEGORY_TOXICITY, - "toxicity": glm.HarmCategory.HARM_CATEGORY_TOXICITY, - "toxic": glm.HarmCategory.HARM_CATEGORY_TOXICITY, - - glm.HarmCategory.HARM_CATEGORY_VIOLENCE: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - HarmCategory.HARM_CATEGORY_VIOLENCE: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - 3: glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - "harm_category_violence": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - "violence": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - "violent": glm.HarmCategory.HARM_CATEGORY_VIOLENCE, - - glm.HarmCategory.HARM_CATEGORY_SEXUAL: glm.HarmCategory.HARM_CATEGORY_SEXUAL, - HarmCategory.HARM_CATEGORY_SEXUAL: glm.HarmCategory.HARM_CATEGORY_SEXUAL, - 4: glm.HarmCategory.HARM_CATEGORY_SEXUAL, - "harm_category_sexual": glm.HarmCategory.HARM_CATEGORY_SEXUAL, - "sexual": glm.HarmCategory.HARM_CATEGORY_SEXUAL, - "sex": glm.HarmCategory.HARM_CATEGORY_SEXUAL, - - glm.HarmCategory.HARM_CATEGORY_MEDICAL: glm.HarmCategory.HARM_CATEGORY_MEDICAL, - HarmCategory.HARM_CATEGORY_MEDICAL: glm.HarmCategory.HARM_CATEGORY_MEDICAL, - 5: glm.HarmCategory.HARM_CATEGORY_MEDICAL, - "harm_category_medical": glm.HarmCategory.HARM_CATEGORY_MEDICAL, - "medical": glm.HarmCategory.HARM_CATEGORY_MEDICAL, - "med": glm.HarmCategory.HARM_CATEGORY_MEDICAL, - - glm.HarmCategory.HARM_CATEGORY_DANGEROUS: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - HarmCategory.HARM_CATEGORY_DANGEROUS: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - 6: glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - "harm_category_dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - "dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - "danger": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, +_HARM_CATEGORIES: Dict[HarmCategoryOptions, protos.HarmCategory] = { + protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + 0: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "harm_category_unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + + protos.HarmCategory.HARM_CATEGORY_DEROGATORY: protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + HarmCategory.HARM_CATEGORY_DEROGATORY: protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + 1: protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + "harm_category_derogatory": protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + "derogatory": protos.HarmCategory.HARM_CATEGORY_DEROGATORY, + + protos.HarmCategory.HARM_CATEGORY_TOXICITY: protos.HarmCategory.HARM_CATEGORY_TOXICITY, + HarmCategory.HARM_CATEGORY_TOXICITY: protos.HarmCategory.HARM_CATEGORY_TOXICITY, + 2: protos.HarmCategory.HARM_CATEGORY_TOXICITY, + "harm_category_toxicity": protos.HarmCategory.HARM_CATEGORY_TOXICITY, + "toxicity": protos.HarmCategory.HARM_CATEGORY_TOXICITY, + "toxic": protos.HarmCategory.HARM_CATEGORY_TOXICITY, + + protos.HarmCategory.HARM_CATEGORY_VIOLENCE: protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + HarmCategory.HARM_CATEGORY_VIOLENCE: protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + 3: protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + "harm_category_violence": protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + "violence": protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + "violent": 
protos.HarmCategory.HARM_CATEGORY_VIOLENCE, + + protos.HarmCategory.HARM_CATEGORY_SEXUAL: protos.HarmCategory.HARM_CATEGORY_SEXUAL, + HarmCategory.HARM_CATEGORY_SEXUAL: protos.HarmCategory.HARM_CATEGORY_SEXUAL, + 4: protos.HarmCategory.HARM_CATEGORY_SEXUAL, + "harm_category_sexual": protos.HarmCategory.HARM_CATEGORY_SEXUAL, + "sexual": protos.HarmCategory.HARM_CATEGORY_SEXUAL, + "sex": protos.HarmCategory.HARM_CATEGORY_SEXUAL, + + protos.HarmCategory.HARM_CATEGORY_MEDICAL: protos.HarmCategory.HARM_CATEGORY_MEDICAL, + HarmCategory.HARM_CATEGORY_MEDICAL: protos.HarmCategory.HARM_CATEGORY_MEDICAL, + 5: protos.HarmCategory.HARM_CATEGORY_MEDICAL, + "harm_category_medical": protos.HarmCategory.HARM_CATEGORY_MEDICAL, + "medical": protos.HarmCategory.HARM_CATEGORY_MEDICAL, + "med": protos.HarmCategory.HARM_CATEGORY_MEDICAL, + + protos.HarmCategory.HARM_CATEGORY_DANGEROUS: protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + HarmCategory.HARM_CATEGORY_DANGEROUS: protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + 6: protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "harm_category_dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "danger": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, } # fmt: on -def to_harm_category(x: HarmCategoryOptions) -> glm.HarmCategory: +def to_harm_category(x: HarmCategoryOptions) -> protos.HarmCategory: if isinstance(x, str): x = x.lower() return _HARM_CATEGORIES[x] @@ -161,7 +161,7 @@ class ContentFilterDict(TypedDict): reason: BlockedReason message: str - __doc__ = string_utils.strip_oneof(glm.ContentFilter.__doc__) + __doc__ = string_utils.strip_oneof(protos.ContentFilter.__doc__) def convert_filters_to_enums( @@ -177,15 +177,15 @@ def convert_filters_to_enums( class SafetyRatingDict(TypedDict): - category: glm.HarmCategory + category: protos.HarmCategory probability: HarmProbability - __doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyRating.__doc__) def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: return { - "category": glm.HarmCategory(rating["category"]), + "category": protos.HarmCategory(rating["category"]), "probability": HarmProbability(rating["probability"]), } @@ -198,10 +198,10 @@ def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]: class SafetySettingDict(TypedDict): - category: glm.HarmCategory + category: protos.HarmCategory threshold: HarmBlockThreshold - __doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetySetting.__doc__) class LooseSafetySettingDict(TypedDict): @@ -251,7 +251,7 @@ def normalize_safety_settings( def convert_setting_to_enum(setting: dict) -> SafetySettingDict: return { - "category": glm.HarmCategory(setting["category"]), + "category": protos.HarmCategory(setting["category"]), "threshold": HarmBlockThreshold(setting["threshold"]), } @@ -260,7 +260,7 @@ class SafetyFeedbackDict(TypedDict): rating: SafetyRatingDict setting: SafetySettingDict - __doc__ = string_utils.strip_oneof(glm.SafetyFeedback.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyFeedback.__doc__) def convert_safety_feedback_to_enums( diff --git a/google/generativeai/types/permission_types.py b/google/generativeai/types/permission_types.py index fde2ddacc..1df831db0 100644 --- a/google/generativeai/types/permission_types.py +++ b/google/generativeai/types/permission_types.py @@ -19,6 +19,7 @@ import re import google.ai.generativelanguage as glm +from 
google.generativeai import protos from google.protobuf import field_mask_pb2 @@ -28,8 +29,8 @@ from google.generativeai import string_utils -GranteeType = glm.Permission.GranteeType -Role = glm.Permission.Role +GranteeType = protos.Permission.GranteeType +Role = protos.Permission.Role GranteeTypeOptions = Union[str, int, GranteeType] RoleOptions = Union[str, int, Role] @@ -108,7 +109,7 @@ def delete( """ if client is None: client = get_default_permission_client() - delete_request = glm.DeletePermissionRequest(name=self.name) + delete_request = protos.DeletePermissionRequest(name=self.name) client.delete_permission(request=delete_request) async def delete_async( @@ -120,7 +121,7 @@ async def delete_async( """ if client is None: client = get_default_permission_async_client() - delete_request = glm.DeletePermissionRequest(name=self.name) + delete_request = protos.DeletePermissionRequest(name=self.name) await client.delete_permission(request=delete_request) # TODO (magashe): Add a method to validate update value. As of now only `role` is supported as a mask path @@ -161,7 +162,7 @@ def update( for path, value in updates.items(): self._apply_update(path, value) - update_request = glm.UpdatePermissionRequest( + update_request = protos.UpdatePermissionRequest( permission=self._to_proto(), update_mask=field_mask ) client.update_permission(request=update_request) @@ -191,14 +192,14 @@ async def update_async( for path, value in updates.items(): self._apply_update(path, value) - update_request = glm.UpdatePermissionRequest( + update_request = protos.UpdatePermissionRequest( permission=self._to_proto(), update_mask=field_mask ) await client.update_permission(request=update_request) return self - def _to_proto(self) -> glm.Permission: - return glm.Permission( + def _to_proto(self) -> protos.Permission: + return protos.Permission( name=self.name, role=self.role, grantee_type=self.grantee_type, @@ -225,7 +226,7 @@ def get( """ if client is None: client = get_default_permission_client() - get_perm_request = glm.GetPermissionRequest(name=name) + get_perm_request = protos.GetPermissionRequest(name=name) get_perm_response = client.get_permission(request=get_perm_request) get_perm_response = type(get_perm_response).to_dict(get_perm_response) return cls(**get_perm_response) @@ -241,7 +242,7 @@ async def get_async( """ if client is None: client = get_default_permission_async_client() - get_perm_request = glm.GetPermissionRequest(name=name) + get_perm_request = protos.GetPermissionRequest(name=name) get_perm_response = await client.get_permission(request=get_perm_request) get_perm_response = type(get_perm_response).to_dict(get_perm_response) return cls(**get_perm_response) @@ -263,7 +264,7 @@ def _make_create_permission_request( role: RoleOptions, grantee_type: Optional[GranteeTypeOptions] = None, email_address: Optional[str] = None, - ) -> glm.CreatePermissionRequest: + ) -> protos.CreatePermissionRequest: role = to_role(role) if grantee_type: @@ -278,12 +279,12 @@ def _make_create_permission_request( f"Invalid operation: An 'email_address' must be provided when 'grantee_type' is not set to 'EVERYONE'. Currently, 'grantee_type' is set to '{grantee_type}' and 'email_address' is '{email_address if email_address else 'not provided'}'." 
) - permission = glm.Permission( + permission = protos.Permission( role=role, grantee_type=grantee_type, email_address=email_address, ) - return glm.CreatePermissionRequest( + return protos.CreatePermissionRequest( parent=self.parent, permission=permission, ) @@ -359,7 +360,7 @@ def list( if client is None: client = get_default_permission_client() - request = glm.ListPermissionsRequest( + request = protos.ListPermissionsRequest( parent=self.parent, page_size=page_size # pytype: disable=attribute-error ) for permission in client.list_permissions(request): @@ -377,7 +378,7 @@ async def list_async( if client is None: client = get_default_permission_async_client() - request = glm.ListPermissionsRequest( + request = protos.ListPermissionsRequest( parent=self.parent, page_size=page_size # pytype: disable=attribute-error ) async for permission in await client.list_permissions(request): @@ -400,7 +401,7 @@ def transfer_ownership( raise NotImplementedError("Can'/t transfer_ownership for a Corpus") if client is None: client = get_default_permission_client() - transfer_request = glm.TransferOwnershipRequest( + transfer_request = protos.TransferOwnershipRequest( name=self.parent, email_address=email_address # pytype: disable=attribute-error ) return client.transfer_ownership(request=transfer_request) @@ -415,7 +416,7 @@ async def transfer_ownership_async( raise NotImplementedError("Can'/t transfer_ownership for a Corpus") if client is None: client = get_default_permission_async_client() - transfer_request = glm.TransferOwnershipRequest( + transfer_request = protos.TransferOwnershipRequest( name=self.parent, email_address=email_address # pytype: disable=attribute-error ) return await client.transfer_ownership(request=transfer_request) diff --git a/google/generativeai/types/retriever_types.py b/google/generativeai/types/retriever_types.py index 294e0b64c..9931ee58d 100644 --- a/google/generativeai/types/retriever_types.py +++ b/google/generativeai/types/retriever_types.py @@ -22,6 +22,7 @@ from typing_extensions import deprecated # type: ignore import google.ai.generativelanguage as glm +from google.generativeai import protos from google.protobuf import field_mask_pb2 from google.generativeai.client import get_default_retriever_client @@ -44,14 +45,14 @@ def valid_name(name): return re.match(_VALID_NAME, name) and len(name) < 40 -Operator = glm.Condition.Operator -State = glm.Chunk.State +Operator = protos.Condition.Operator +State = protos.Chunk.State OperatorOptions = Union[str, int, Operator] StateOptions = Union[str, int, State] ChunkOptions = Union[ - glm.Chunk, + protos.Chunk, str, tuple[str, str], tuple[str, str, Any], @@ -59,17 +60,17 @@ def valid_name(name): ] # fmt: no BatchCreateChunkOptions = Union[ - glm.BatchCreateChunksRequest, + protos.BatchCreateChunksRequest, Mapping[str, str], Mapping[str, tuple[str, str]], Iterable[ChunkOptions], ] # fmt: no -UpdateChunkOptions = Union[glm.UpdateChunkRequest, Mapping[str, Any], tuple[str, Any]] +UpdateChunkOptions = Union[protos.UpdateChunkRequest, Mapping[str, Any], tuple[str, Any]] -BatchUpdateChunksOptions = Union[glm.BatchUpdateChunksRequest, Iterable[UpdateChunkOptions]] +BatchUpdateChunksOptions = Union[protos.BatchUpdateChunksRequest, Iterable[UpdateChunkOptions]] -BatchDeleteChunkOptions = Union[list[glm.DeleteChunkRequest], Iterable[str]] +BatchDeleteChunkOptions = Union[list[protos.DeleteChunkRequest], Iterable[str]] _OPERATOR: dict[OperatorOptions, Operator] = { Operator.OPERATOR_UNSPECIFIED: Operator.OPERATOR_UNSPECIFIED, @@ -163,10 
+164,10 @@ def _to_proto(self): ) kwargs["operation"] = c.operation - condition = glm.Condition(**kwargs) + condition = protos.Condition(**kwargs) conditions.append(condition) - return glm.MetadataFilter(key=self.key, conditions=conditions) + return protos.MetadataFilter(key=self.key, conditions=conditions) @string_utils.prettyprint @@ -188,17 +189,17 @@ def _to_proto(self): kwargs["string_value"] = self.value elif isinstance(self.value, Iterable): if isinstance(self.value, Mapping): - # If already converted to a glm.StringList, get the values + # If already converted to a protos.StringList, get the values kwargs["string_list_value"] = self.value else: - kwargs["string_list_value"] = glm.StringList(values=self.value) + kwargs["string_list_value"] = protos.StringList(values=self.value) elif isinstance(self.value, (int, float)): kwargs["numeric_value"] = float(self.value) else: raise ValueError( f"Invalid value type: The value for a custom_metadata specification must be either a list of string values, a string, or an integer/float. Received: '{self.value}' of type {type(self.value).__name__}." ) - return glm.CustomMetadata(key=self.key, **kwargs) + return protos.CustomMetadata(key=self.key, **kwargs) @classmethod def _from_dict(cls, cm): @@ -216,14 +217,14 @@ def _to_dict(self): return type(proto).to_dict(proto) -CustomMetadataOptions = Union[CustomMetadata, glm.CustomMetadata, dict] +CustomMetadataOptions = Union[CustomMetadata, protos.CustomMetadata, dict] def make_custom_metadata(cm: CustomMetadataOptions) -> CustomMetadata: if isinstance(cm, CustomMetadata): return cm - if isinstance(cm, glm.CustomMetadata): + if isinstance(cm, protos.CustomMetadata): cm = type(cm).to_dict(cm) if isinstance(cm, dict): @@ -293,9 +294,9 @@ def create_document( c_data.append(cm._to_proto()) if name is None: - document = glm.Document(display_name=display_name, custom_metadata=c_data) + document = protos.Document(display_name=display_name, custom_metadata=c_data) elif valid_name(name): - document = glm.Document( + document = protos.Document( name=f"{self.name}/documents/{name}", display_name=display_name, custom_metadata=c_data, @@ -303,7 +304,7 @@ def create_document( else: raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateDocumentRequest(parent=self.name, document=document) + request = protos.CreateDocumentRequest(parent=self.name, document=document) response = client.create_document(request, **request_options) return decode_document(response) @@ -329,9 +330,9 @@ async def create_document_async( c_data.append(cm._to_proto()) if name is None: - document = glm.Document(display_name=display_name, custom_metadata=c_data) + document = protos.Document(display_name=display_name, custom_metadata=c_data) elif valid_name(name): - document = glm.Document( + document = protos.Document( name=f"{self.name}/documents/{name}", display_name=display_name, custom_metadata=c_data, @@ -339,7 +340,7 @@ async def create_document_async( else: raise ValueError(NAME_ERROR_MSG.format(length=len(name), name=name)) - request = glm.CreateDocumentRequest(parent=self.name, document=document) + request = protos.CreateDocumentRequest(parent=self.name, document=document) response = await client.create_document(request, **request_options) return decode_document(response) @@ -368,7 +369,7 @@ def get_document( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.GetDocumentRequest(name=name) + request = protos.GetDocumentRequest(name=name) response = client.get_document(request, 
**request_options) return decode_document(response) @@ -388,7 +389,7 @@ async def get_document_async( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.GetDocumentRequest(name=name) + request = protos.GetDocumentRequest(name=name) response = await client.get_document(request, **request_options) return decode_document(response) @@ -434,7 +435,7 @@ def update( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) + request = protos.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) client.update_corpus(request, **request_options) return self @@ -465,7 +466,7 @@ async def update_async( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) + request = protos.UpdateCorpusRequest(corpus=self.to_dict(), update_mask=field_mask) await client.update_corpus(request, **request_options) return self @@ -506,7 +507,7 @@ def query( for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryCorpusRequest( + request = protos.QueryCorpusRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -551,7 +552,7 @@ async def query_async( for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryCorpusRequest( + request = protos.QueryCorpusRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -594,7 +595,7 @@ def delete_document( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.DeleteDocumentRequest(name=name, force=bool(force)) + request = protos.DeleteDocumentRequest(name=name, force=bool(force)) client.delete_document(request, **request_options) async def delete_document_async( @@ -614,7 +615,7 @@ async def delete_document_async( if "/" not in name: name = f"{self.name}/documents/{name}" - request = glm.DeleteDocumentRequest(name=name, force=bool(force)) + request = protos.DeleteDocumentRequest(name=name, force=bool(force)) await client.delete_document(request, **request_options) def list_documents( @@ -640,7 +641,7 @@ def list_documents( if client is None: client = get_default_retriever_client() - request = glm.ListDocumentsRequest( + request = protos.ListDocumentsRequest( parent=self.name, page_size=page_size, ) @@ -660,7 +661,7 @@ async def list_documents_async( if client is None: client = get_default_retriever_async_client() - request = glm.ListDocumentsRequest( + request = protos.ListDocumentsRequest( parent=self.name, page_size=page_size, ) @@ -792,15 +793,17 @@ def create_chunk( chunk_name = name if isinstance(data, str): - chunk = glm.Chunk(name=chunk_name, data={"string_value": data}, custom_metadata=c_data) + chunk = protos.Chunk( + name=chunk_name, data={"string_value": data}, custom_metadata=c_data + ) else: - chunk = glm.Chunk( + chunk = protos.Chunk( name=chunk_name, data={"string_value": data.string_value}, custom_metadata=c_data, ) - request = glm.CreateChunkRequest(parent=self.name, chunk=chunk) + request = protos.CreateChunkRequest(parent=self.name, chunk=chunk) response = client.create_chunk(request, **request_options) return decode_chunk(response) @@ -834,24 +837,26 @@ async def create_chunk_async( chunk_name = name if isinstance(data, str): - chunk = glm.Chunk(name=chunk_name, data={"string_value": data}, custom_metadata=c_data) + chunk = protos.Chunk( + name=chunk_name, data={"string_value": data}, custom_metadata=c_data + ) else: - chunk = glm.Chunk( + chunk = protos.Chunk( 
name=chunk_name, data={"string_value": data.string_value}, custom_metadata=c_data, ) - request = glm.CreateChunkRequest(parent=self.name, chunk=chunk) + request = protos.CreateChunkRequest(parent=self.name, chunk=chunk) response = await client.create_chunk(request, **request_options) return decode_chunk(response) - def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: + def _make_chunk(self, chunk: ChunkOptions) -> protos.Chunk: # del self - if isinstance(chunk, glm.Chunk): - return glm.Chunk(chunk) + if isinstance(chunk, protos.Chunk): + return protos.Chunk(chunk) elif isinstance(chunk, str): - return glm.Chunk(data={"string_value": chunk}) + return protos.Chunk(data={"string_value": chunk}) elif isinstance(chunk, tuple): if len(chunk) == 2: name, data = chunk # pytype: disable=bad-unpacking @@ -864,7 +869,7 @@ def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: f"value: {chunk}" ) - return glm.Chunk( + return protos.Chunk( name=name, data={"string_value": data}, custom_metadata=custom_metadata, @@ -873,7 +878,7 @@ def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: if isinstance(chunk["data"], str): chunk = dict(chunk) chunk["data"] = {"string_value": chunk["data"]} - return glm.Chunk(chunk) + return protos.Chunk(chunk) else: raise TypeError( f"Invalid input: Could not convert instance of type '{type(chunk).__name__}' to a chunk. Received value: '{chunk}'." @@ -881,8 +886,8 @@ def _make_chunk(self, chunk: ChunkOptions) -> glm.Chunk: def _make_batch_create_chunk_request( self, chunks: BatchCreateChunkOptions - ) -> glm.BatchCreateChunksRequest: - if isinstance(chunks, glm.BatchCreateChunksRequest): + ) -> protos.BatchCreateChunksRequest: + if isinstance(chunks, protos.BatchCreateChunksRequest): return chunks if isinstance(chunks, Mapping): @@ -901,9 +906,9 @@ def _make_batch_create_chunk_request( chunk.name = f"{self.name}/chunks/{chunk.name}" - requests.append(glm.CreateChunkRequest(parent=self.name, chunk=chunk)) + requests.append(protos.CreateChunkRequest(parent=self.name, chunk=chunk)) - return glm.BatchCreateChunksRequest(parent=self.name, requests=requests) + return protos.BatchCreateChunksRequest(parent=self.name, requests=requests) def batch_create_chunks( self, @@ -973,7 +978,7 @@ def get_chunk( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.GetChunkRequest(name=name) + request = protos.GetChunkRequest(name=name) response = client.get_chunk(request, **request_options) return decode_chunk(response) @@ -993,7 +998,7 @@ async def get_chunk_async( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.GetChunkRequest(name=name) + request = protos.GetChunkRequest(name=name) response = await client.get_chunk(request, **request_options) return decode_chunk(response) @@ -1019,7 +1024,7 @@ def list_chunks( if client is None: client = get_default_retriever_client() - request = glm.ListChunksRequest(parent=self.name, page_size=page_size) + request = protos.ListChunksRequest(parent=self.name, page_size=page_size) for chunk in client.list_chunks(request, **request_options): yield decode_chunk(chunk) @@ -1036,7 +1041,7 @@ async def list_chunks_async( if client is None: client = get_default_retriever_async_client() - request = glm.ListChunksRequest(parent=self.name, page_size=page_size) + request = protos.ListChunksRequest(parent=self.name, page_size=page_size) async for chunk in await client.list_chunks(request, **request_options): yield decode_chunk(chunk) @@ -1076,7 +1081,7 @@ def query( for mf in metadata_filters: 
m_f_.append(mf._to_proto()) - request = glm.QueryDocumentRequest( + request = protos.QueryDocumentRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -1121,7 +1126,7 @@ async def query_async( for mf in metadata_filters: m_f_.append(mf._to_proto()) - request = glm.QueryDocumentRequest( + request = protos.QueryDocumentRequest( name=self.name, query=query, metadata_filters=m_f_, @@ -1181,7 +1186,7 @@ def update( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) + request = protos.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) client.update_document(request, **request_options) return self @@ -1211,7 +1216,7 @@ async def update_async( for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) + request = protos.UpdateDocumentRequest(document=self.to_dict(), update_mask=field_mask) await client.update_document(request, **request_options) return self @@ -1237,7 +1242,7 @@ def batch_update_chunks( if client is None: client = get_default_retriever_client() - if isinstance(chunks, glm.BatchUpdateChunksRequest): + if isinstance(chunks, protos.BatchUpdateChunksRequest): response = client.batch_update_chunks(chunks) response = type(response).to_dict(response) return response @@ -1270,15 +1275,17 @@ def batch_update_chunks( for path, value in updates.items(): chunk_to_update._apply_update(path, value) _requests.append( - glm.UpdateChunkRequest(chunk=chunk_to_update.to_dict(), update_mask=field_mask) + protos.UpdateChunkRequest( + chunk=chunk_to_update.to_dict(), update_mask=field_mask + ) ) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response if isinstance(chunks, Iterable) and not isinstance(chunks, Mapping): for chunk in chunks: - if isinstance(chunk, glm.UpdateChunkRequest): + if isinstance(chunk, protos.UpdateChunkRequest): _requests.append(chunk) elif isinstance(chunk, tuple): # First element is name of chunk, second element contains updates @@ -1304,9 +1311,10 @@ def batch_update_chunks( ) else: raise TypeError( - "Invalid input: The 'chunks' parameter must be a list of 'glm.UpdateChunkRequests', dictionaries, or tuples of dictionaries." + "Invalid input: The 'chunks' parameter must be a list of 'protos.UpdateChunkRequests'," + " dictionaries, or tuples of dictionaries." 
) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response @@ -1324,7 +1332,7 @@ async def batch_update_chunks_async( if client is None: client = get_default_retriever_async_client() - if isinstance(chunks, glm.BatchUpdateChunksRequest): + if isinstance(chunks, protos.BatchUpdateChunksRequest): response = client.batch_update_chunks(chunks) response = type(response).to_dict(response) return response @@ -1357,15 +1365,17 @@ async def batch_update_chunks_async( for path, value in updates.items(): chunk_to_update._apply_update(path, value) _requests.append( - glm.UpdateChunkRequest(chunk=chunk_to_update.to_dict(), update_mask=field_mask) + protos.UpdateChunkRequest( + chunk=chunk_to_update.to_dict(), update_mask=field_mask + ) ) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = await client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response if isinstance(chunks, Iterable) and not isinstance(chunks, Mapping): for chunk in chunks: - if isinstance(chunk, glm.UpdateChunkRequest): + if isinstance(chunk, protos.UpdateChunkRequest): _requests.append(chunk) elif isinstance(chunk, tuple): # First element is name of chunk, second element contains updates @@ -1391,9 +1401,10 @@ async def batch_update_chunks_async( ) else: raise TypeError( - "Invalid input: The 'chunks' parameter must be a list of 'glm.UpdateChunkRequests', dictionaries, or tuples of dictionaries." + "Invalid input: The 'chunks' parameter must be a list of 'protos.UpdateChunkRequests', " + "dictionaries, or tuples of dictionaries." 
) - request = glm.BatchUpdateChunksRequest(parent=self.name, requests=_requests) + request = protos.BatchUpdateChunksRequest(parent=self.name, requests=_requests) response = await client.batch_update_chunks(request, **request_options) response = type(response).to_dict(response) return response @@ -1420,7 +1431,7 @@ def delete_chunk( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.DeleteChunkRequest(name=name) + request = protos.DeleteChunkRequest(name=name) client.delete_chunk(request, **request_options) async def delete_chunk_async( @@ -1439,7 +1450,7 @@ async def delete_chunk_async( if "/" not in name: name = f"{self.name}/chunks/{name}" - request = glm.DeleteChunkRequest(name=name) + request = protos.DeleteChunkRequest(name=name) await client.delete_chunk(request, **request_options) def batch_delete_chunks( @@ -1461,18 +1472,19 @@ def batch_delete_chunks( if client is None: client = get_default_retriever_client() - if all(isinstance(x, glm.DeleteChunkRequest) for x in chunks): - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=chunks) + if all(isinstance(x, protos.DeleteChunkRequest) for x in chunks): + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=chunks) client.batch_delete_chunks(request, **request_options) elif isinstance(chunks, Iterable): _request_list = [] for chunk_name in chunks: - _request_list.append(glm.DeleteChunkRequest(name=chunk_name)) - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) + _request_list.append(protos.DeleteChunkRequest(name=chunk_name)) + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) client.batch_delete_chunks(request, **request_options) else: raise ValueError( - "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple 'glm.DeleteChunkRequest's." + "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, " + "or multiple 'protos.DeleteChunkRequest's." ) async def batch_delete_chunks_async( @@ -1488,18 +1500,19 @@ async def batch_delete_chunks_async( if client is None: client = get_default_retriever_async_client() - if all(isinstance(x, glm.DeleteChunkRequest) for x in chunks): - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=chunks) + if all(isinstance(x, protos.DeleteChunkRequest) for x in chunks): + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=chunks) await client.batch_delete_chunks(request, **request_options) elif isinstance(chunks, Iterable): _request_list = [] for chunk_name in chunks: - _request_list.append(glm.DeleteChunkRequest(name=chunk_name)) - request = glm.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) + _request_list.append(protos.DeleteChunkRequest(name=chunk_name)) + request = protos.BatchDeleteChunksRequest(parent=self.name, requests=_request_list) await client.batch_delete_chunks(request, **request_options) else: raise ValueError( - "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, or multiple 'glm.DeleteChunkRequest's." + "Invalid operation: To delete chunks, you must pass in either the names of the chunks as an iterable, " + "or multiple 'protos.DeleteChunkRequest's." 
) def to_dict(self) -> dict[str, Any]: @@ -1511,7 +1524,7 @@ def to_dict(self) -> dict[str, Any]: return result -def decode_chunk(chunk: glm.Chunk) -> Chunk: +def decode_chunk(chunk: protos.Chunk) -> Chunk: chunk = type(chunk).to_dict(chunk) idecode_time(chunk, "create_time") idecode_time(chunk, "update_time") @@ -1625,7 +1638,7 @@ def update( field_mask.paths.append(path) for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) + request = protos.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) client.update_chunk(request, **request_options) return self @@ -1665,7 +1678,7 @@ async def update_async( field_mask.paths.append(path) for path, value in updates.items(): self._apply_update(path, value) - request = glm.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) + request = protos.UpdateChunkRequest(chunk=self.to_dict(), update_mask=field_mask) await client.update_chunk(request, **request_options) return self diff --git a/google/generativeai/types/safety_types.py b/google/generativeai/types/safety_types.py index c8368da7f..74da06e45 100644 --- a/google/generativeai/types/safety_types.py +++ b/google/generativeai/types/safety_types.py @@ -23,7 +23,7 @@ from typing_extensions import TypedDict -from google.ai import generativelanguage as glm +from google.generativeai import protos from google.generativeai import string_utils @@ -39,9 +39,9 @@ ] # These are basic python enums, it's okay to expose them -HarmProbability = glm.SafetyRating.HarmProbability -HarmBlockThreshold = glm.SafetySetting.HarmBlockThreshold -BlockedReason = glm.ContentFilter.BlockedReason +HarmProbability = protos.SafetyRating.HarmProbability +HarmBlockThreshold = protos.SafetySetting.HarmBlockThreshold +BlockedReason = protos.ContentFilter.BlockedReason import proto @@ -51,57 +51,57 @@ class HarmCategory(proto.Enum): Harm Categories supported by the gemini-family model """ - HARM_CATEGORY_UNSPECIFIED = glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value - HARM_CATEGORY_HARASSMENT = glm.HarmCategory.HARM_CATEGORY_HARASSMENT.value - HARM_CATEGORY_HATE_SPEECH = glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH.value - HARM_CATEGORY_SEXUALLY_EXPLICIT = glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT.value - HARM_CATEGORY_DANGEROUS_CONTENT = glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT.value + HARM_CATEGORY_UNSPECIFIED = protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED.value + HARM_CATEGORY_HARASSMENT = protos.HarmCategory.HARM_CATEGORY_HARASSMENT.value + HARM_CATEGORY_HATE_SPEECH = protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH.value + HARM_CATEGORY_SEXUALLY_EXPLICIT = protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT.value + HARM_CATEGORY_DANGEROUS_CONTENT = protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT.value HarmCategoryOptions = Union[str, int, HarmCategory] # fmt: off -_HARM_CATEGORIES: Dict[HarmCategoryOptions, glm.HarmCategory] = { - glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - HarmCategory.HARM_CATEGORY_UNSPECIFIED: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - 0: glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "harm_category_unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - "unspecified": glm.HarmCategory.HARM_CATEGORY_UNSPECIFIED, +_HARM_CATEGORIES: Dict[HarmCategoryOptions, protos.HarmCategory] = { + protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + HarmCategory.HARM_CATEGORY_UNSPECIFIED: 
protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + 0: protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "harm_category_unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, + "unspecified": protos.HarmCategory.HARM_CATEGORY_UNSPECIFIED, - 7: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - glm.HarmCategory.HARM_CATEGORY_HARASSMENT: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - HarmCategory.HARM_CATEGORY_HARASSMENT: glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - "harm_category_harassment": glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - "harassment": glm.HarmCategory.HARM_CATEGORY_HARASSMENT, - - 8: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - HarmCategory.HARM_CATEGORY_HATE_SPEECH: glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'harm_category_hate_speech': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'hate_speech': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - 'hate': glm.HarmCategory.HARM_CATEGORY_HATE_SPEECH, - - 9: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "harm_category_sexually_explicit": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "harm_category_sexual": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sexually_explicit": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sexual": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - "sex": glm.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, - - 10: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "harm_category_dangerous_content": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "harm_category_dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "dangerous": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - "danger": glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + 7: protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + protos.HarmCategory.HARM_CATEGORY_HARASSMENT: protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + HarmCategory.HARM_CATEGORY_HARASSMENT: protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + "harm_category_harassment": protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + "harassment": protos.HarmCategory.HARM_CATEGORY_HARASSMENT, + + 8: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + HarmCategory.HARM_CATEGORY_HATE_SPEECH: protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'harm_category_hate_speech': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate_speech': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + 'hate': protos.HarmCategory.HARM_CATEGORY_HATE_SPEECH, + + 9: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexually_explicit": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "harm_category_sexual": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sexually_explicit": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sexual": 
protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + "sex": protos.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, + + 10: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous_content": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "harm_category_dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "dangerous": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + "danger": protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, } # fmt: on -def to_harm_category(x: HarmCategoryOptions) -> glm.HarmCategory: +def to_harm_category(x: HarmCategoryOptions) -> protos.HarmCategory: if isinstance(x, str): x = x.lower() return _HARM_CATEGORIES[x] @@ -150,7 +150,7 @@ class ContentFilterDict(TypedDict): reason: BlockedReason message: str - __doc__ = string_utils.strip_oneof(glm.ContentFilter.__doc__) + __doc__ = string_utils.strip_oneof(protos.ContentFilter.__doc__) def convert_filters_to_enums( @@ -166,15 +166,15 @@ def convert_filters_to_enums( class SafetyRatingDict(TypedDict): - category: glm.HarmCategory + category: protos.HarmCategory probability: HarmProbability - __doc__ = string_utils.strip_oneof(glm.SafetyRating.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyRating.__doc__) def convert_rating_to_enum(rating: dict) -> SafetyRatingDict: return { - "category": glm.HarmCategory(rating["category"]), + "category": protos.HarmCategory(rating["category"]), "probability": HarmProbability(rating["probability"]), } @@ -187,10 +187,10 @@ def convert_ratings_to_enum(ratings: Iterable[dict]) -> List[SafetyRatingDict]: class SafetySettingDict(TypedDict): - category: glm.HarmCategory + category: protos.HarmCategory threshold: HarmBlockThreshold - __doc__ = string_utils.strip_oneof(glm.SafetySetting.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetySetting.__doc__) class LooseSafetySettingDict(TypedDict): @@ -225,7 +225,7 @@ def to_easy_safety_dict(settings: SafetySettingOptions) -> EasySafetySettingDict else: # Iterable result = {} for setting in settings: - if isinstance(setting, glm.SafetySetting): + if isinstance(setting, protos.SafetySetting): result[to_harm_category(setting.category)] = to_block_threshold(setting.threshold) elif isinstance(setting, dict): result[to_harm_category(setting["category"])] = to_block_threshold( @@ -267,7 +267,7 @@ def normalize_safety_settings( def convert_setting_to_enum(setting: dict) -> SafetySettingDict: return { - "category": glm.HarmCategory(setting["category"]), + "category": protos.HarmCategory(setting["category"]), "threshold": HarmBlockThreshold(setting["threshold"]), } @@ -276,7 +276,7 @@ class SafetyFeedbackDict(TypedDict): rating: SafetyRatingDict setting: SafetySettingDict - __doc__ = string_utils.strip_oneof(glm.SafetyFeedback.__doc__) + __doc__ = string_utils.strip_oneof(protos.SafetyFeedback.__doc__) def convert_safety_feedback_to_enums( diff --git a/tests/test_answer.py b/tests/test_answer.py index 4128567f4..2669b207c 100644 --- a/tests/test_answer.py +++ b/tests/test_answer.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import answer from google.generativeai import types as genai_types @@ -47,14 +47,14 @@ def add_client_method(f): @add_client_method def 
generate_answer( - request: glm.GenerateAnswerRequest, + request: protos.GenerateAnswerRequest, **kwargs, - ) -> glm.GenerateAnswerResponse: + ) -> protos.GenerateAnswerResponse: self.observed_requests.append(request) - return glm.GenerateAnswerResponse( - answer=glm.Candidate( + return protos.GenerateAnswerResponse( + answer=protos.Candidate( index=1, - content=(glm.Content(parts=[glm.Part(text="Demo answer.")])), + content=(protos.Content(parts=[protos.Part(text="Demo answer.")])), ), answerable_probability=0.500, ) @@ -62,17 +62,23 @@ def generate_answer( def test_make_grounding_passages_mixed_types(self): inline_passages = [ "I am a chicken", - glm.Content(parts=[glm.Part(text="I am a bird.")]), - glm.Content(parts=[glm.Part(text="I can fly!")]), + protos.Content(parts=[protos.Part(text="I am a bird.")]), + protos.Content(parts=[protos.Part(text="I can fly!")]), ] x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + { + "id": "0", + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "1", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), x, @@ -82,23 +88,29 @@ def test_make_grounding_passages_mixed_types(self): [ dict( testcase_name="grounding_passage", - inline_passages=glm.GroundingPassages( + inline_passages=protos.GroundingPassages( passages=[ { "id": "0", - "content": glm.Content(parts=[glm.Part(text="I am a chicken")]), + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "1", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + { + "id": "2", + "content": protos.Content(parts=[protos.Part(text="I can fly!")]), }, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, ] ), ), dict( testcase_name="content_object", inline_passages=[ - glm.Content(parts=[glm.Part(text="I am a chicken")]), - glm.Content(parts=[glm.Part(text="I am a bird.")]), - glm.Content(parts=[glm.Part(text="I can fly!")]), + protos.Content(parts=[protos.Part(text="I am a chicken")]), + protos.Content(parts=[protos.Part(text="I am a bird.")]), + protos.Content(parts=[protos.Part(text="I can fly!")]), ], ), dict( @@ -109,13 +121,19 @@ def test_make_grounding_passages_mixed_types(self): ) def test_make_grounding_passages(self, inline_passages): x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + { + "id": "0", + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "1", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + 
{"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), x, @@ -133,27 +151,33 @@ def test_make_grounding_passages(self, inline_passages): dict( testcase_name="list_of_grounding_passages", inline_passages=[ - glm.GroundingPassage( - id="4", content=glm.Content(parts=[glm.Part(text="I am a chicken")]) + protos.GroundingPassage( + id="4", content=protos.Content(parts=[protos.Part(text="I am a chicken")]) ), - glm.GroundingPassage( - id="5", content=glm.Content(parts=[glm.Part(text="I am a bird.")]) + protos.GroundingPassage( + id="5", content=protos.Content(parts=[protos.Part(text="I am a bird.")]) ), - glm.GroundingPassage( - id="6", content=glm.Content(parts=[glm.Part(text="I can fly!")]) + protos.GroundingPassage( + id="6", content=protos.Content(parts=[protos.Part(text="I can fly!")]) ), ], ), ) def test_make_grounding_passages_different_id(self, inline_passages): x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ - {"id": "4", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "5", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "6", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + { + "id": "4", + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "5", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + {"id": "6", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ), x, @@ -167,16 +191,22 @@ def test_make_grounding_passages_key_strings(self): } x = answer._make_grounding_passages(inline_passages) - self.assertIsInstance(x, glm.GroundingPassages) + self.assertIsInstance(x, protos.GroundingPassages) self.assertEqual( - glm.GroundingPassages( + protos.GroundingPassages( passages=[ { "id": "first", - "content": glm.Content(parts=[glm.Part(text="I am a chicken")]), + "content": protos.Content(parts=[protos.Part(text="I am a chicken")]), + }, + { + "id": "second", + "content": protos.Content(parts=[protos.Part(text="I am a bird.")]), + }, + { + "id": "third", + "content": protos.Content(parts=[protos.Part(text="I can fly!")]), }, - {"id": "second", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "third", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, ] ), x, @@ -184,14 +214,14 @@ def test_make_grounding_passages_key_strings(self): def test_generate_answer_request(self): # Should be a list of contents to use to_contents() function. 
- contents = [glm.Content(parts=[glm.Part(text="I have wings.")])] + contents = [protos.Content(parts=[protos.Part(text="I have wings.")])] inline_passages = ["I am a chicken", "I am a bird.", "I can fly!"] - grounding_passages = glm.GroundingPassages( + grounding_passages = protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + {"id": "0", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, + {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ) @@ -200,7 +230,7 @@ def test_generate_answer_request(self): ) self.assertEqual( - glm.GenerateAnswerRequest( + protos.GenerateAnswerRequest( model=DEFAULT_ANSWER_MODEL, contents=contents, inline_passages=grounding_passages ), x, @@ -208,13 +238,13 @@ def test_generate_answer_request(self): def test_generate_answer(self): # Test handling return value of generate_answer(). - contents = [glm.Content(parts=[glm.Part(text="I have wings.")])] + contents = [protos.Content(parts=[protos.Part(text="I have wings.")])] - grounding_passages = glm.GroundingPassages( + grounding_passages = protos.GroundingPassages( passages=[ - {"id": "0", "content": glm.Content(parts=[glm.Part(text="I am a chicken")])}, - {"id": "1", "content": glm.Content(parts=[glm.Part(text="I am a bird.")])}, - {"id": "2", "content": glm.Content(parts=[glm.Part(text="I can fly!")])}, + {"id": "0", "content": protos.Content(parts=[protos.Part(text="I am a chicken")])}, + {"id": "1", "content": protos.Content(parts=[protos.Part(text="I am a bird.")])}, + {"id": "2", "content": protos.Content(parts=[protos.Part(text="I can fly!")])}, ] ) @@ -225,13 +255,13 @@ def test_generate_answer(self): answer_style="ABSTRACTIVE", ) - self.assertIsInstance(a, glm.GenerateAnswerResponse) + self.assertIsInstance(a, protos.GenerateAnswerResponse) self.assertEqual( a, - glm.GenerateAnswerResponse( - answer=glm.Candidate( + protos.GenerateAnswerResponse( + answer=protos.Candidate( index=1, - content=(glm.Content(parts=[glm.Part(text="Demo answer.")])), + content=(protos.Content(parts=[protos.Part(text="Demo answer.")])), ), answerable_probability=0.500, ), diff --git a/tests/test_client.py b/tests/test_client.py index 0256edac3..0cc3e05eb 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,8 +4,10 @@ from absl.testing import absltest from absl.testing import parameterized -from google.api_core import client_options import google.ai.generativelanguage as glm + +from google.api_core import client_options +from google.generativeai import protos from google.generativeai import client diff --git a/tests/test_content.py b/tests/test_content.py index 5f22b93a1..3829ebc86 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -19,7 +19,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.types import content_types import IPython.display import PIL.Image @@ -71,7 +71,7 @@ class UnitTests(parameterized.TestCase): ) def test_png_to_blob(self, image): blob = content_types.image_to_blob(image) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") 
self.assertStartsWith(blob.data, b"\x89PNG") @@ -81,29 +81,29 @@ def test_png_to_blob(self, image): ) def test_jpg_to_blob(self, image): blob = content_types.image_to_blob(image) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/jpeg") self.assertStartsWith(blob.data, b"\xff\xd8\xff\xe0\x00\x10JFIF") @parameterized.named_parameters( ["BlobDict", {"mime_type": "image/png", "data": TEST_PNG_DATA}], - ["glm.Blob", glm.Blob(mime_type="image/png", data=TEST_PNG_DATA)], + ["protos.Blob", protos.Blob(mime_type="image/png", data=TEST_PNG_DATA)], ["Image", IPython.display.Image(filename=TEST_PNG_PATH)], ) def test_to_blob(self, example): blob = content_types.to_blob(example) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @parameterized.named_parameters( ["dict", {"text": "Hello world!"}], - ["glm.Part", glm.Part(text="Hello world!")], + ["protos.Part", protos.Part(text="Hello world!")], ["str", "Hello world!"], ) def test_to_part(self, example): part = content_types.to_part(example) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") @parameterized.named_parameters( @@ -116,12 +116,12 @@ def test_to_part(self, example): ) def test_img_to_part(self, example): blob = content_types.to_part(example).inline_data - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @parameterized.named_parameters( - ["glm.Content", glm.Content(parts=[{"text": "Hello world!"}])], + ["protos.Content", protos.Content(parts=[{"text": "Hello world!"}])], ["ContentDict", {"parts": [{"text": "Hello world!"}]}], ["ContentDict-str", {"parts": ["Hello world!"]}], ["list[parts]", [{"text": "Hello world!"}]], @@ -135,7 +135,7 @@ def test_to_content(self, example): part = content.parts[0] self.assertLen(content.parts, 1) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") @parameterized.named_parameters( @@ -147,12 +147,12 @@ def test_img_to_content(self, example): content = content_types.to_content(example) blob = content.parts[0].inline_data self.assertLen(content.parts, 1) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @parameterized.named_parameters( - ["glm.Content", glm.Content(parts=[{"text": "Hello world!"}])], + ["protos.Content", protos.Content(parts=[{"text": "Hello world!"}])], ["ContentDict", {"parts": [{"text": "Hello world!"}]}], ["ContentDict-str", {"parts": ["Hello world!"]}], ) @@ -161,7 +161,7 @@ def test_strict_to_content(self, example): part = content.parts[0] self.assertLen(content.parts, 1) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") @parameterized.named_parameters( @@ -176,7 +176,7 @@ def test_strict_to_contents_fails(self, examples): content_types.strict_to_content(examples) @parameterized.named_parameters( - ["glm.Content", [glm.Content(parts=[{"text": "Hello world!"}])]], + ["protos.Content", [protos.Content(parts=[{"text": "Hello world!"}])]], ["ContentDict", [{"parts": [{"text": "Hello world!"}]}]], ["ContentDict-unwraped", [{"parts": 
["Hello world!"]}]], ["ContentDict+str-part", [{"parts": "Hello world!"}]], @@ -188,7 +188,7 @@ def test_to_contents(self, example): self.assertLen(contents, 1) self.assertLen(contents[0].parts, 1) - self.assertIsInstance(part, glm.Part) + self.assertIsInstance(part, protos.Part) self.assertEqual(part.text, "Hello world!") def test_dict_to_content_fails(self): @@ -209,7 +209,7 @@ def test_img_to_contents(self, example): self.assertLen(contents, 1) self.assertLen(contents[0].parts, 1) - self.assertIsInstance(blob, glm.Blob) + self.assertIsInstance(blob, protos.Blob) self.assertEqual(blob.mime_type, "image/png") self.assertStartsWith(blob.data, b"\x89PNG") @@ -217,9 +217,9 @@ def test_img_to_contents(self, example): [ "FunctionLibrary", content_types.FunctionLibrary( - tools=glm.Tool( + tools=protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -231,7 +231,7 @@ def test_img_to_contents(self, example): [ content_types.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -239,11 +239,11 @@ def test_img_to_contents(self, example): ], ], [ - "IterableTool-glm.Tool", + "IterableTool-protos.Tool", [ - glm.Tool( + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -268,7 +268,7 @@ def test_img_to_contents(self, example): "IterableTool-IterableFD", [ [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -278,7 +278,7 @@ def test_img_to_contents(self, example): [ "IterableTool-FD", [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -288,17 +288,17 @@ def test_img_to_contents(self, example): "Tool", content_types.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] ), ], [ - "glm.Tool", - glm.Tool( + "protos.Tool", + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -350,8 +350,8 @@ def test_img_to_contents(self, example): ), ], [ - "glm.FD", - glm.FunctionDeclaration( + "protos.FD", + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." 
), ], @@ -391,83 +391,83 @@ def b(): self.assertLen(tools[0].function_declarations, 2) @parameterized.named_parameters( - ["int", int, glm.Schema(type=glm.Type.INTEGER)], - ["float", float, glm.Schema(type=glm.Type.NUMBER)], - ["str", str, glm.Schema(type=glm.Type.STRING)], - ["nullable_str", Union[str, None], glm.Schema(type=glm.Type.STRING, nullable=True)], + ["int", int, protos.Schema(type=protos.Type.INTEGER)], + ["float", float, protos.Schema(type=protos.Type.NUMBER)], + ["str", str, protos.Schema(type=protos.Type.STRING)], + ["nullable_str", Union[str, None], protos.Schema(type=protos.Type.STRING, nullable=True)], [ "list", list[str], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.STRING), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.STRING), ), ], [ "list-list-int", list[list[int]], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema( - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.INTEGER), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema( + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.INTEGER), ), ), ), ], - ["dict", dict, glm.Schema(type=glm.Type.OBJECT)], - ["dict-str-any", dict[str, Any], glm.Schema(type=glm.Type.OBJECT)], + ["dict", dict, protos.Schema(type=protos.Type.OBJECT)], + ["dict-str-any", dict[str, Any], protos.Schema(type=protos.Type.OBJECT)], [ "dataclass", ADataClass, - glm.Schema( - type=glm.Type.OBJECT, - properties={"a": {"type_": glm.Type.INTEGER}}, + protos.Schema( + type=protos.Type.OBJECT, + properties={"a": {"type_": protos.Type.INTEGER}}, ), ], [ "nullable_dataclass", Union[ADataClass, None], - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, nullable=True, - properties={"a": {"type_": glm.Type.INTEGER}}, + properties={"a": {"type_": protos.Type.INTEGER}}, ), ], [ "list_of_dataclass", list[ADataClass], - glm.Schema( + protos.Schema( type="ARRAY", - items=glm.Schema( - type=glm.Type.OBJECT, - properties={"a": {"type_": glm.Type.INTEGER}}, + items=protos.Schema( + type=protos.Type.OBJECT, + properties={"a": {"type_": protos.Type.INTEGER}}, ), ), ], [ "dataclass_with_nullable", ADataClassWithNullable, - glm.Schema( - type=glm.Type.OBJECT, - properties={"a": {"type_": glm.Type.INTEGER, "nullable": True}}, + protos.Schema( + type=protos.Type.OBJECT, + properties={"a": {"type_": protos.Type.INTEGER, "nullable": True}}, ), ], [ "dataclass_with_list", ADataClassWithList, - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, properties={"a": {"type_": "ARRAY", "items": {"type_": "INTEGER"}}}, ), ], [ "list_of_dataclass_with_list", list[ADataClassWithList], - glm.Schema( - items=glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + items=protos.Schema( + type=protos.Type.OBJECT, properties={"a": {"type_": "ARRAY", "items": {"type_": "INTEGER"}}}, ), type="ARRAY", @@ -476,31 +476,31 @@ def b(): [ "list_of_nullable", list[Union[int, None]], - glm.Schema( + protos.Schema( type="ARRAY", - items={"type_": glm.Type.INTEGER, "nullable": True}, + items={"type_": protos.Type.INTEGER, "nullable": True}, ), ], [ "TypedDict", ATypedDict, - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, properties={ - "a": {"type_": glm.Type.INTEGER}, + "a": {"type_": protos.Type.INTEGER}, }, ), ], [ "nested", Nested, - glm.Schema( - type=glm.Type.OBJECT, + protos.Schema( + type=protos.Type.OBJECT, properties={ - "x": glm.Schema( - type=glm.Type.OBJECT, + "x": 
protos.Schema( + type=protos.Type.OBJECT, properties={ - "a": {"type_": glm.Type.INTEGER}, + "a": {"type_": protos.Type.INTEGER}, }, ), }, diff --git a/tests/test_discuss.py b/tests/test_discuss.py index 7db0a63d8..4e54cf754 100644 --- a/tests/test_discuss.py +++ b/tests/test_discuss.py @@ -16,7 +16,7 @@ import unittest.mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import discuss from google.generativeai import client @@ -37,18 +37,18 @@ def setUp(self): self.observed_request = None - self.mock_response = glm.GenerateMessageResponse( + self.mock_response = protos.GenerateMessageResponse( candidates=[ - glm.Message(content="a", author="1"), - glm.Message(content="b", author="1"), - glm.Message(content="c", author="1"), + protos.Message(content="a", author="1"), + protos.Message(content="b", author="1"), + protos.Message(content="c", author="1"), ], ) def fake_generate_message( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, **kwargs, - ) -> glm.GenerateMessageResponse: + ) -> protos.GenerateMessageResponse: self.observed_request = request response = copy.copy(self.mock_response) response.messages = request.prompt.messages @@ -60,22 +60,22 @@ def fake_generate_message( ["string", "Hello", ""], ["dict", {"content": "Hello"}, ""], ["dict_author", {"content": "Hello", "author": "me"}, "me"], - ["proto", glm.Message(content="Hello"), ""], - ["proto_author", glm.Message(content="Hello", author="me"), "me"], + ["proto", protos.Message(content="Hello"), ""], + ["proto_author", protos.Message(content="Hello", author="me"), "me"], ) def test_make_message(self, message, author): x = discuss._make_message(message) - self.assertIsInstance(x, glm.Message) + self.assertIsInstance(x, protos.Message) self.assertEqual("Hello", x.content) self.assertEqual(author, x.author) @parameterized.named_parameters( ["string", "Hello", ["Hello"]], ["dict", {"content": "Hello"}, ["Hello"]], - ["proto", glm.Message(content="Hello"), ["Hello"]], + ["proto", protos.Message(content="Hello"), ["Hello"]], [ "list", - ["hello0", {"content": "hello1"}, glm.Message(content="hello2")], + ["hello0", {"content": "hello1"}, protos.Message(content="hello2")], ["hello0", "hello1", "hello2"], ], ) @@ -90,15 +90,15 @@ def test_make_messages(self, messages, expected_contents): ["dict", {"input": "hello", "output": "goodbye"}], [ "proto", - glm.Example( - input=glm.Message(content="hello"), - output=glm.Message(content="goodbye"), + protos.Example( + input=protos.Message(content="hello"), + output=protos.Message(content="goodbye"), ), ], ) def test_make_example(self, example): x = discuss._make_example(example) - self.assertIsInstance(x, glm.Example) + self.assertIsInstance(x, protos.Example) self.assertEqual("hello", x.input.content) self.assertEqual("goodbye", x.output.content) return @@ -110,7 +110,7 @@ def test_make_example(self, example): "Hi", {"content": "Hello!"}, "what's your name?", - glm.Message(content="Dave, what's yours"), + protos.Message(content="Dave, what's yours"), ], ], [ @@ -145,15 +145,15 @@ def test_make_examples_from_example(self): @parameterized.named_parameters( ["str", "hello"], - ["message", glm.Message(content="hello")], + ["message", protos.Message(content="hello")], ["messages", ["hello"]], ["dict", {"messages": "hello"}], ["dict2", {"messages": ["hello"]}], - ["proto", glm.MessagePrompt(messages=[glm.Message(content="hello")])], + ["proto", protos.MessagePrompt(messages=[protos.Message(content="hello")])], ) 
def test_make_message_prompt_from_messages(self, prompt): x = discuss._make_message_prompt(prompt) - self.assertIsInstance(x, glm.MessagePrompt) + self.assertIsInstance(x, protos.MessagePrompt) self.assertEqual(x.messages[0].content, "hello") return @@ -181,15 +181,15 @@ def test_make_message_prompt_from_messages(self, prompt): [ "proto", [ - glm.MessagePrompt( + protos.MessagePrompt( context="you are a cat", examples=[ - glm.Example( - input=glm.Message(content="are you hungry?"), - output=glm.Message(content="meow!"), + protos.Example( + input=protos.Message(content="are you hungry?"), + output=protos.Message(content="meow!"), ) ], - messages=[glm.Message(content="hello")], + messages=[protos.Message(content="hello")], ) ], {}, @@ -197,7 +197,7 @@ def test_make_message_prompt_from_messages(self, prompt): ) def test_make_message_prompt_from_prompt(self, args, kwargs): x = discuss._make_message_prompt(*args, **kwargs) - self.assertIsInstance(x, glm.MessagePrompt) + self.assertIsInstance(x, protos.MessagePrompt) self.assertEqual(x.context, "you are a cat") self.assertEqual(x.examples[0].input.content, "are you hungry?") self.assertEqual(x.examples[0].output.content, "meow!") @@ -229,8 +229,8 @@ def test_make_generate_message_request_nested( } ) - self.assertIsInstance(request0, glm.GenerateMessageRequest) - self.assertIsInstance(request1, glm.GenerateMessageRequest) + self.assertIsInstance(request0, protos.GenerateMessageRequest) + self.assertIsInstance(request1, protos.GenerateMessageRequest) self.assertEqual(request0, request1) @parameterized.parameters( @@ -285,11 +285,13 @@ def test_reply(self, kwargs): response = response.reply("again") def test_receive_and_reply_with_filters(self): - self.mock_response = mock_response = glm.GenerateMessageResponse( - candidates=[glm.Message(content="a", author="1")], + self.mock_response = mock_response = protos.GenerateMessageResponse( + candidates=[protos.Message(content="a", author="1")], filters=[ - glm.ContentFilter(reason=palm_safety_types.BlockedReason.SAFETY, message="unsafe"), - glm.ContentFilter(reason=palm_safety_types.BlockedReason.OTHER), + protos.ContentFilter( + reason=palm_safety_types.BlockedReason.SAFETY, message="unsafe" + ), + protos.ContentFilter(reason=palm_safety_types.BlockedReason.OTHER), ], ) response = discuss.chat(messages="do filters work?") @@ -300,10 +302,12 @@ def test_receive_and_reply_with_filters(self): self.assertEqual(filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) self.assertEqual(filters[0]["message"], "unsafe") - self.mock_response = glm.GenerateMessageResponse( - candidates=[glm.Message(content="a", author="1")], + self.mock_response = protos.GenerateMessageResponse( + candidates=[protos.Message(content="a", author="1")], filters=[ - glm.ContentFilter(reason=palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED) + protos.ContentFilter( + reason=palm_safety_types.BlockedReason.BLOCKED_REASON_UNSPECIFIED + ) ], ) @@ -317,7 +321,7 @@ def test_receive_and_reply_with_filters(self): ) def test_chat_citations(self): - self.mock_response = mock_response = glm.GenerateMessageResponse( + self.mock_response = mock_response = protos.GenerateMessageResponse( candidates=[ { "content": "Hello google!", diff --git a/tests/test_discuss_async.py b/tests/test_discuss_async.py index 7e1f7947c..d35d03525 100644 --- a/tests/test_discuss_async.py +++ b/tests/test_discuss_async.py @@ -17,7 +17,7 @@ from typing import Any import unittest -import google.ai.generativelanguage as glm +from google.generativeai import 
protos from google.generativeai import discuss from absl.testing import absltest @@ -31,14 +31,14 @@ async def test_chat_async(self): observed_request = None async def fake_generate_message( - request: glm.GenerateMessageRequest, + request: protos.GenerateMessageRequest, **kwargs, - ) -> glm.GenerateMessageResponse: + ) -> protos.GenerateMessageResponse: nonlocal observed_request observed_request = request - return glm.GenerateMessageResponse( + return protos.GenerateMessageResponse( candidates=[ - glm.Message( + protos.Message( author="1", content="Why did the chicken cross the road?", ) @@ -59,17 +59,17 @@ async def fake_generate_message( self.assertEqual( observed_request, - glm.GenerateMessageRequest( + protos.GenerateMessageRequest( model="models/bard", - prompt=glm.MessagePrompt( + prompt=protos.MessagePrompt( context="Example Prompt", examples=[ - glm.Example( - input=glm.Message(content="Example from human"), - output=glm.Message(content="Example response from AI"), + protos.Example( + input=protos.Message(content="Example from human"), + output=protos.Message(content="Example response from AI"), ) ], - messages=[glm.Message(author="0", content="Tell me a joke")], + messages=[protos.Message(author="0", content="Tell me a joke")], ), temperature=0.75, candidate_count=1, diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 5f6aa8d89..a208a4743 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import embedding @@ -45,20 +45,20 @@ def add_client_method(f): @add_client_method def embed_content( - request: glm.EmbedContentRequest, + request: protos.EmbedContentRequest, **kwargs, - ) -> glm.EmbedContentResponse: + ) -> protos.EmbedContentResponse: self.observed_requests.append(request) - return glm.EmbedContentResponse(embedding=glm.ContentEmbedding(values=[1, 2, 3])) + return protos.EmbedContentResponse(embedding=protos.ContentEmbedding(values=[1, 2, 3])) @add_client_method def batch_embed_contents( - request: glm.BatchEmbedContentsRequest, + request: protos.BatchEmbedContentsRequest, **kwargs, - ) -> glm.BatchEmbedContentsResponse: + ) -> protos.BatchEmbedContentsResponse: self.observed_requests.append(request) - return glm.BatchEmbedContentsResponse( - embeddings=[glm.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) + return protos.BatchEmbedContentsResponse( + embeddings=[protos.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) ) def test_embed_content(self): @@ -68,8 +68,9 @@ def test_embed_content(self): self.assertIsInstance(emb, dict) self.assertEqual( self.observed_requests[-1], - glm.EmbedContentRequest( - model=DEFAULT_EMB_MODEL, content=glm.Content(parts=[glm.Part(text="What are you?")]) + protos.EmbedContentRequest( + model=DEFAULT_EMB_MODEL, + content=protos.Content(parts=[protos.Part(text="What are you?")]), ), ) self.assertIsInstance(emb["embedding"][0], float) diff --git a/tests/test_embedding_async.py b/tests/test_embedding_async.py index d4ca16c08..367cf7ded 100644 --- a/tests/test_embedding_async.py +++ b/tests/test_embedding_async.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import embedding @@ -44,20 +44,20 @@ def add_client_method(f): @add_client_method async def embed_content( - request: glm.EmbedContentRequest, + 
request: protos.EmbedContentRequest, **kwargs, - ) -> glm.EmbedContentResponse: + ) -> protos.EmbedContentResponse: self.observed_requests.append(request) - return glm.EmbedContentResponse(embedding=glm.ContentEmbedding(values=[1, 2, 3])) + return protos.EmbedContentResponse(embedding=protos.ContentEmbedding(values=[1, 2, 3])) @add_client_method async def batch_embed_contents( - request: glm.BatchEmbedContentsRequest, + request: protos.BatchEmbedContentsRequest, **kwargs, - ) -> glm.BatchEmbedContentsResponse: + ) -> protos.BatchEmbedContentsResponse: self.observed_requests.append(request) - return glm.BatchEmbedContentsResponse( - embeddings=[glm.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) + return protos.BatchEmbedContentsResponse( + embeddings=[protos.ContentEmbedding(values=[1, 2, 3])] * len(request.requests) ) async def test_embed_content_async(self): @@ -67,8 +67,9 @@ async def test_embed_content_async(self): self.assertIsInstance(emb, dict) self.assertEqual( self.observed_requests[-1], - glm.EmbedContentRequest( - model=DEFAULT_EMB_MODEL, content=glm.Content(parts=[glm.Part(text="What are you?")]) + protos.EmbedContentRequest( + model=DEFAULT_EMB_MODEL, + content=protos.Content(parts=[protos.Part(text="What are you?")]), ), ) self.assertIsInstance(emb["embedding"][0], float) diff --git a/tests/test_files.py b/tests/test_files.py index 333ec1e2a..7d9139450 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -22,10 +22,10 @@ import pathlib import google -import google.ai.generativelanguage as glm import google.generativeai as genai from google.generativeai import client as client_lib +from google.generativeai import protos from absl.testing import parameterized @@ -43,7 +43,7 @@ def create_file( name: Union[str, None] = None, display_name: Union[str, None] = None, resumable: bool = True, - ) -> glm.File: + ) -> protos.File: self.observed_requests.append( dict( path=path, @@ -57,24 +57,24 @@ def create_file( def get_file( self, - request: glm.GetFileRequest, + request: protos.GetFileRequest, **kwargs, - ) -> glm.File: + ) -> protos.File: self.observed_requests.append(request) return self.responses["get_file"].pop(0) def list_files( self, - request: glm.ListFilesRequest, + request: protos.ListFilesRequest, **kwargs, - ) -> Iterable[glm.File]: + ) -> Iterable[protos.File]: self.observed_requests.append(request) for f in self.responses["list_files"].pop(0): yield f def delete_file( self, - request: glm.DeleteFileRequest, + request: protos.DeleteFileRequest, **kwargs, ): self.observed_requests.append(request) @@ -97,7 +97,7 @@ def responses(self): def test_video_metadata(self): self.responses["create_file"].append( - glm.File( + protos.File( uri="https://test", state="ACTIVE", video_metadata=dict(video_duration=datetime.timedelta(seconds=30)), @@ -108,7 +108,8 @@ def test_video_metadata(self): f = genai.upload_file(path="dummy") self.assertEqual(google.rpc.status_pb2.Status(code=7, message="ok?"), f.error) self.assertEqual( - glm.VideoMetadata(dict(video_duration=datetime.timedelta(seconds=30))), f.video_metadata + protos.VideoMetadata(dict(video_duration=datetime.timedelta(seconds=30))), + f.video_metadata, ) @parameterized.named_parameters( @@ -123,11 +124,11 @@ def test_video_metadata(self): ), dict( testcase_name="FileData", - file_data=glm.FileData(file_uri="https://test_uri"), + file_data=protos.FileData(file_uri="https://test_uri"), ), dict( - testcase_name="glm.File", - file_data=glm.File(uri="https://test_uri"), + testcase_name="protos.File", + 
+
file_data=protos.File(uri="https://test_uri"), ), dict( testcase_name="file_types.File", @@ -137,4 +138,4 @@ def test_video_metadata(self): ) def test_to_file_data(self, file_data): file_data = file_types.to_file_data(file_data) - self.assertEqual(glm.FileData(file_uri="https://test_uri"), file_data) + self.assertEqual(protos.FileData(file_uri="https://test_uri"), file_data) diff --git a/tests/test_generation.py b/tests/test_generation.py index b256a1029..828577d21 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -5,7 +5,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai.types import generation_types @@ -24,9 +24,11 @@ class Person(TypedDict): class UnitTests(parameterized.TestCase): @parameterized.named_parameters( [ - "glm.GenerationConfig", - glm.GenerationConfig( - temperature=0.1, stop_sequences=["end"], response_schema=glm.Schema(type="STRING") + "protos.GenerationConfig", + protos.GenerationConfig( + temperature=0.1, + stop_sequences=["end"], + response_schema=protos.Schema(type="STRING"), ), ], [ @@ -48,15 +50,15 @@ def test_to_generation_config(self, config): def test_join_citation_metadatas(self): citations = [ - glm.CitationMetadata( + protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource(start_index=3, end_index=21, uri="https://google.com"), ] ), - glm.CitationMetadata( + protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=3, end_index=33, uri="https://google.com"), - glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), + protos.CitationSource(start_index=3, end_index=33, uri="https://google.com"), + protos.CitationSource(start_index=55, end_index=92, uri="https://google.com"), ] ), ] @@ -74,14 +76,14 @@ def test_join_citation_metadatas(self): def test_join_safety_ratings_list(self): ratings = [ [ - glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), - glm.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="LOW"), - glm.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="MEDIUM"), + protos.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="MEDIUM"), ], [ - glm.SafetyRating(category="HARM_CATEGORY_DEROGATORY", probability="LOW"), - glm.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="LOW"), - glm.SafetyRating( + protos.SafetyRating(category="HARM_CATEGORY_DEROGATORY", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_SEXUAL", probability="LOW"), + protos.SafetyRating( category="HARM_CATEGORY_DANGEROUS", probability="HIGH", blocked=True, @@ -101,14 +103,14 @@ def test_join_safety_ratings_list(self): def test_join_contents(self): contents = [ - glm.Content(role="assistant", parts=[glm.Part(text="Tell me a story about a ")]), - glm.Content( + protos.Content(role="assistant", parts=[protos.Part(text="Tell me a story about a ")]), + protos.Content( role="assistant", - parts=[glm.Part(text="magic backpack that looks like this: ")], + parts=[protos.Part(text="magic backpack that looks like this: ")], ), - glm.Content( + protos.Content( role="assistant", - parts=[glm.Part(inline_data=glm.Blob(mime_type="image/png", data=b"DATA!"))], + 
parts=[protos.Part(inline_data=protos.Blob(mime_type="image/png", data=b"DATA!"))], ), ] result = generation_types._join_contents(contents) @@ -126,7 +128,8 @@ def test_many_join_contents(self): import string contents = [ - glm.Content(role="assistant", parts=[glm.Part(text=a)]) for a in string.ascii_lowercase + protos.Content(role="assistant", parts=[protos.Part(text=a)]) + for a in string.ascii_lowercase ] result = generation_types._join_contents(contents) @@ -139,41 +142,53 @@ def test_many_join_contents(self): def test_join_candidates(self): candidates = [ - glm.Candidate( + protos.Candidate( index=0, - content=glm.Content( + content=protos.Content( role="assistant", - parts=[glm.Part(text="Tell me a story about a ")], + parts=[protos.Part(text="Tell me a story about a ")], ), - citation_metadata=glm.CitationMetadata( + citation_metadata=protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=55, end_index=85, uri="https://google.com"), + protos.CitationSource( + start_index=55, end_index=85, uri="https://google.com" + ), ] ), ), - glm.Candidate( + protos.Candidate( index=0, - content=glm.Content( + content=protos.Content( role="assistant", - parts=[glm.Part(text="magic backpack that looks like this: ")], + parts=[protos.Part(text="magic backpack that looks like this: ")], ), - citation_metadata=glm.CitationMetadata( + citation_metadata=protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), - glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource( + start_index=55, end_index=92, uri="https://google.com" + ), + protos.CitationSource( + start_index=3, end_index=21, uri="https://google.com" + ), ] ), ), - glm.Candidate( + protos.Candidate( index=0, - content=glm.Content( + content=protos.Content( role="assistant", - parts=[glm.Part(inline_data=glm.Blob(mime_type="image/png", data=b"DATA!"))], + parts=[ + protos.Part(inline_data=protos.Blob(mime_type="image/png", data=b"DATA!")) + ], ), - citation_metadata=glm.CitationMetadata( + citation_metadata=protos.CitationMetadata( citation_sources=[ - glm.CitationSource(start_index=55, end_index=92, uri="https://google.com"), - glm.CitationSource(start_index=3, end_index=21, uri="https://google.com"), + protos.CitationSource( + start_index=55, end_index=92, uri="https://google.com" + ), + protos.CitationSource( + start_index=3, end_index=21, uri="https://google.com" + ), ] ), finish_reason="STOP", @@ -213,17 +228,17 @@ def test_join_candidates(self): def test_join_prompt_feedbacks(self): feedbacks = [ - glm.GenerateContentResponse.PromptFeedback( + protos.GenerateContentResponse.PromptFeedback( block_reason="SAFETY", safety_ratings=[ - glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), ], ), - glm.GenerateContentResponse.PromptFeedback(), - glm.GenerateContentResponse.PromptFeedback(), - glm.GenerateContentResponse.PromptFeedback( + protos.GenerateContentResponse.PromptFeedback(), + protos.GenerateContentResponse.PromptFeedback(), + protos.GenerateContentResponse.PromptFeedback( safety_ratings=[ - glm.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="HIGH"), + protos.SafetyRating(category="HARM_CATEGORY_MEDICAL", probability="HIGH"), ] ), ] @@ -396,23 +411,23 @@ def test_join_prompt_feedbacks(self): ] def test_join_candidates(self): - candidate_lists = [[glm.Candidate(c) for c in cl] for cl in 
self.CANDIDATE_LISTS] + candidate_lists = [[protos.Candidate(c) for c in cl] for cl in self.CANDIDATE_LISTS] result = generation_types._join_candidate_lists(candidate_lists) self.assertEqual(self.MERGED_CANDIDATES, [type(r).to_dict(r) for r in result]) def test_join_chunks(self): - chunks = [glm.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] + chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] - chunks[0].prompt_feedback = glm.GenerateContentResponse.PromptFeedback( + chunks[0].prompt_feedback = protos.GenerateContentResponse.PromptFeedback( block_reason="SAFETY", safety_ratings=[ - glm.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), + protos.SafetyRating(category="HARM_CATEGORY_DANGEROUS", probability="LOW"), ], ) result = generation_types._join_chunks(chunks) - expected = glm.GenerateContentResponse( + expected = protos.GenerateContentResponse( { "candidates": self.MERGED_CANDIDATES, "prompt_feedback": { @@ -431,7 +446,7 @@ def test_join_chunks(self): self.assertEqual(type(expected).to_dict(expected), type(result).to_dict(expected)) def test_generate_content_response_iterator_end_to_end(self): - chunks = [glm.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] + chunks = [protos.GenerateContentResponse(candidates=cl) for cl in self.CANDIDATE_LISTS] merged = generation_types._join_chunks(chunks) response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -453,7 +468,7 @@ def test_generate_content_response_iterator_end_to_end(self): def test_generate_content_response_multiple_iterators(self): chunks = [ - glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) for a in string.ascii_lowercase ] response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -483,7 +498,7 @@ def test_generate_content_response_multiple_iterators(self): def test_generate_content_response_resolve(self): chunks = [ - glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) for a in "abcd" ] response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -497,7 +512,7 @@ def test_generate_content_response_resolve(self): self.assertEqual(response.candidates[0].content.parts[0].text, "abcd") def test_generate_content_response_from_response(self): - raw_response = glm.GenerateContentResponse( + raw_response = protos.GenerateContentResponse( {"candidates": [{"content": {"parts": [{"text": "Hello world!"}]}}]} ) response = generation_types.GenerateContentResponse.from_response(raw_response) @@ -511,7 +526,7 @@ def test_generate_content_response_from_response(self): ) def test_repr_for_generate_content_response_from_response(self): - raw_response = glm.GenerateContentResponse( + raw_response = protos.GenerateContentResponse( {"candidates": [{"content": {"parts": [{"text": "Hello world!"}]}}]} ) response = generation_types.GenerateContentResponse.from_response(raw_response) @@ -523,7 +538,7 @@ def test_repr_for_generate_content_response_from_response(self): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -542,7 +557,7 @@ def test_repr_for_generate_content_response_from_response(self): def 
test_repr_for_generate_content_response_from_iterator(self): chunks = [ - glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) + protos.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": a}]}}]}) for a in "abcd" ] response = generation_types.GenerateContentResponse.from_iterator(iter(chunks)) @@ -554,7 +569,7 @@ def test_repr_for_generate_content_response_from_iterator(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -573,35 +588,35 @@ def test_repr_for_generate_content_response_from_iterator(self): @parameterized.named_parameters( [ - "glm.Schema", - glm.Schema(type="STRING"), - glm.Schema(type="STRING"), + "protos.Schema", + protos.Schema(type="STRING"), + protos.Schema(type="STRING"), ], [ "SchemaDict", {"type": "STRING"}, - glm.Schema(type="STRING"), + protos.Schema(type="STRING"), ], [ "str", str, - glm.Schema(type="STRING"), + protos.Schema(type="STRING"), ], - ["list_of_str", list[str], glm.Schema(type="ARRAY", items=glm.Schema(type="STRING"))], + ["list_of_str", list[str], protos.Schema(type="ARRAY", items=protos.Schema(type="STRING"))], [ "fancy", Person, - glm.Schema( + protos.Schema( type="OBJECT", properties=dict( - name=glm.Schema(type="STRING"), - favorite_color=glm.Schema(type="STRING"), - birthday=glm.Schema( + name=protos.Schema(type="STRING"), + favorite_color=protos.Schema(type="STRING"), + birthday=protos.Schema( type="OBJECT", properties=dict( - day=glm.Schema(type="INTEGER"), - month=glm.Schema(type="INTEGER"), - year=glm.Schema(type="INTEGER"), + day=protos.Schema(type="INTEGER"), + month=protos.Schema(type="INTEGER"), + year=protos.Schema(type="INTEGER"), ), ), ), diff --git a/tests/test_generative_models.py b/tests/test_generative_models.py index 4a0f86991..0ece77e94 100644 --- a/tests/test_generative_models.py +++ b/tests/test_generative_models.py @@ -7,7 +7,7 @@ import unittest.mock from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import client as client_lib from google.generativeai import generative_models from google.generativeai.types import content_types @@ -23,20 +23,20 @@ TEST_IMAGE_DATA = TEST_IMAGE_PATH.read_bytes() -def noop(x: int): - return x +def simple_part(text: str) -> protos.Content: + return protos.Content({"parts": [{"text": text}]}) -def simple_part(text: str) -> glm.Content: - return glm.Content({"parts": [{"text": text}]}) +def noop(x: int): + return x -def iter_part(texts: Iterable[str]) -> glm.Content: - return glm.Content({"parts": [{"text": t} for t in texts]}) +def iter_part(texts: Iterable[str]) -> protos.Content: + return protos.Content({"parts": [{"text": t} for t in texts]}) -def simple_response(text: str) -> glm.GenerateContentResponse: - return glm.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]}) +def simple_response(text: str) -> protos.GenerateContentResponse: + return protos.GenerateContentResponse({"candidates": [{"content": simple_part(text)}]}) class MockGenerativeServiceClient: @@ -48,10 +48,10 @@ def __init__(self, test): def generate_content( self, - request: glm.GenerateContentRequest, + request: protos.GenerateContentRequest, **kwargs, - ) -> glm.GenerateContentResponse: - self.test.assertIsInstance(request, glm.GenerateContentRequest) + ) -> protos.GenerateContentResponse: + 
self.test.assertIsInstance(request, protos.GenerateContentRequest) self.observed_requests.append(request) self.observed_kwargs.append(kwargs) response = self.responses["generate_content"].pop(0) @@ -59,9 +59,9 @@ def generate_content( def stream_generate_content( self, - request: glm.GetModelRequest, + request: protos.GetModelRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) self.observed_kwargs.append(kwargs) response = self.responses["stream_generate_content"].pop(0) @@ -69,9 +69,9 @@ def stream_generate_content( def count_tokens( self, - request: glm.CountTokensRequest, + request: protos.CountTokensRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) self.observed_kwargs.append(kwargs) response = self.responses["count_tokens"].pop(0) @@ -149,9 +149,9 @@ def test_image(self, content): generation_types.GenerationConfig(temperature=0.5), ], [ - "glm", - glm.GenerationConfig(temperature=0.0), - glm.GenerationConfig(temperature=0.5), + "protos", + protos.GenerationConfig(temperature=0.0), + protos.GenerationConfig(temperature=0.5), ], ) def test_generation_config_overwrite(self, config1, config2): @@ -176,8 +176,8 @@ def test_generation_config_overwrite(self, config1, config2): "list-dict", [ dict( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ), ], [ @@ -187,15 +187,15 @@ def test_generation_config_overwrite(self, config1, config2): [ "object", [ - glm.SafetySetting( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + protos.SafetySetting( + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ), ], [ - glm.SafetySetting( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, + protos.SafetySetting( + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, ), ], ], @@ -214,22 +214,22 @@ def test_safety_overwrite(self, safe1, safe2): danger = [ s for s in self.observed_requests[-1].safety_settings - if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT + if s.category == protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT ] self.assertEqual( danger[0].threshold, - glm.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, + protos.SafetySetting.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE, ) _ = model.generate_content("hello", safety_settings=safe2) danger = [ s for s in self.observed_requests[-1].safety_settings - if s.category == glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT + if s.category == protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT ] self.assertEqual( danger[0].threshold, - glm.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, + protos.SafetySetting.HarmBlockThreshold.BLOCK_ONLY_HIGH, ) def test_stream_basic(self): @@ -263,7 +263,7 @@ def test_stream_lookahead(self): def test_stream_prompt_feedback_blocked(self): chunks = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": {"block_reason": "SAFETY"}, } @@ -276,7 
+276,7 @@ def test_stream_prompt_feedback_blocked(self): self.assertEqual( response.prompt_feedback.block_reason, - glm.GenerateContentResponse.PromptFeedback.BlockReason.SAFETY, + protos.GenerateContentResponse.PromptFeedback.BlockReason.SAFETY, ) with self.assertRaises(generation_types.BlockedPromptException): @@ -285,20 +285,20 @@ def test_stream_prompt_feedback_blocked(self): def test_stream_prompt_feedback_not_blocked(self): chunks = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": { "safety_ratings": [ { - "category": glm.HarmCategory.HARM_CATEGORY_DANGEROUS, - "probability": glm.SafetyRating.HarmProbability.NEGLIGIBLE, + "category": protos.HarmCategory.HARM_CATEGORY_DANGEROUS, + "probability": protos.SafetyRating.HarmProbability.NEGLIGIBLE, } ] }, "candidates": [{"content": {"parts": [{"text": "first"}]}}], } ), - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"content": {"parts": [{"text": " second"}]}}], } @@ -311,7 +311,7 @@ def test_stream_prompt_feedback_not_blocked(self): self.assertEqual( response.prompt_feedback.safety_ratings[0].category, - glm.HarmCategory.HARM_CATEGORY_DANGEROUS, + protos.HarmCategory.HARM_CATEGORY_DANGEROUS, ) text = "".join(chunk.text for chunk in response) @@ -544,7 +544,7 @@ def no_throw(): def test_chat_prompt_blocked(self): self.responses["generate_content"] = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": {"block_reason": "SAFETY"}, } @@ -562,7 +562,7 @@ def test_chat_prompt_blocked(self): def test_chat_candidate_blocked(self): # I feel like chat needs a .last so you can look at the partial results. self.responses["generate_content"] = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "SAFETY"}], } @@ -582,7 +582,7 @@ def test_chat_streaming_unexpected_stop(self): simple_response("a"), simple_response("b"), simple_response("c"), - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "SAFETY"}], } @@ -669,9 +669,9 @@ def test_tools(self): }, ), dict( - testcase_name="test_glm_FunctionCallingConfig", + testcase_name="test_protos.FunctionCallingConfig", tool_config={ - "function_calling_config": glm.FunctionCallingConfig( + "function_calling_config": protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.AUTO ) }, @@ -698,9 +698,9 @@ def test_tools(self): }, ), dict( - testcase_name="test_glm_ToolConfig", - tool_config=glm.ToolConfig( - function_calling_config=glm.FunctionCallingConfig( + testcase_name="test_protos.ToolConfig", + tool_config=protos.ToolConfig( + function_calling_config=protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.NONE ) ), @@ -773,7 +773,7 @@ def test_system_instruction(self, instruction, expected_instr): ) def test_count_tokens_smoke(self, kwargs): si = kwargs.pop("system_instruction", None) - self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)] + self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] model = generative_models.GenerativeModel("gemini-pro-vision", system_instruction=si) response = model.count_tokens(**kwargs) self.assertEqual(type(response).to_dict(response), {"total_tokens": 7}) @@ -840,7 +840,7 @@ def test_repr_for_unary_non_streamed_response(self): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -873,7 +873,7 
+873,7
@@ def test_repr_for_streaming_start_to_finish(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -898,7 +898,7 @@ def test_repr_for_streaming_start_to_finish(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -927,7 +927,7 @@ def test_repr_for_streaming_start_to_finish(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -951,7 +951,7 @@ def test_repr_for_streaming_start_to_finish(self): def test_repr_error_info_for_stream_prompt_feedback_blocked(self): # response._error => BlockedPromptException chunks = [ - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "prompt_feedback": {"block_reason": "SAFETY"}, } @@ -969,7 +969,7 @@ def test_repr_error_info_for_stream_prompt_feedback_blocked(self): GenerateContentResponse( done=False, iterator=, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "prompt_feedback": { "block_reason": "SAFETY" } @@ -1019,7 +1019,7 @@ def no_throw(): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -1049,7 +1049,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self): simple_response("a"), simple_response("b"), simple_response("c"), - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "SAFETY"}], } @@ -1078,7 +1078,7 @@ def test_repr_error_info_for_chat_streaming_unexpected_stop(self): GenerateContentResponse( done=True, iterator=None, - result=glm.GenerateContentResponse({ + result=protos.GenerateContentResponse({ "candidates": [ { "content": { @@ -1141,7 +1141,7 @@ def test_repr_for_multi_turn_chat(self): tools=None, system_instruction=None, ), - history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), glm.Content({'parts': [{'text': 'first'}], 'role': 'model'}), glm.Content({'parts': [{'text': 'I also like this image.'}, {'inline_data': {'data': 'iVBORw0KGgoA...AAElFTkSuQmCC', 'mime_type': 'image/png'}}], 'role': 'user'}), glm.Content({'parts': [{'text': 'second'}], 'role': 'model'}), glm.Content({'parts': [{'text': 'What things do I like?.'}], 'role': 'user'}), glm.Content({'parts': [{'text': 'third'}], 'role': 'model'})] + history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'first'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'I also like this image.'}, {'inline_data': {'data': 'iVBORw0KGgoA...AAElFTkSuQmCC', 'mime_type': 'image/png'}}], 'role': 'user'}), protos.Content({'parts': [{'text': 'second'}], 'role': 'model'}), protos.Content({'parts': [{'text': 'What things do I like?.'}], 'role': 'user'}), protos.Content({'parts': [{'text': 'third'}], 'role': 'model'})] )""" ) self.assertEqual(expected, result) @@ -1169,7 +1169,7 @@ def test_repr_for_incomplete_streaming_chat(self): tools=None, system_instruction=None, ), - history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] + history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )""" ) self.assertEqual(expected, result) @@ -1185,7 +1185,7 
+1185,7
@@ def test_repr_for_broken_streaming_chat(self): for chunk in [ simple_response("first"), # FinishReason.SAFETY = 3 - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [ {"finish_reason": 3, "content": {"parts": [{"text": "second"}]}} @@ -1213,7 +1213,7 @@ def test_repr_for_broken_streaming_chat(self): tools=None, system_instruction=None, ), - history=[glm.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] + history=[protos.Content({'parts': [{'text': 'I really like fantasy books.'}], 'role': 'user'}), ] )""" ) self.assertEqual(expected, result) @@ -1224,7 +1224,7 @@ def test_repr_for_system_instruction(self): self.assertIn("system_instruction='Be excellent.'", result) def test_count_tokens_called_with_request_options(self): - self.responses["count_tokens"].append(glm.CountTokensResponse()) + self.responses["count_tokens"].append(protos.CountTokensResponse(total_tokens=7)) request_options = {"timeout": 120} model = generative_models.GenerativeModel("gemini-pro-vision") @@ -1234,7 +1234,7 @@ def test_count_tokens_called_with_request_options(self): def test_chat_with_request_options(self): self.responses["generate_content"].append( - glm.GenerateContentResponse( + protos.GenerateContentResponse( { "candidates": [{"finish_reason": "STOP"}], } diff --git a/tests/test_generative_models_async.py b/tests/test_generative_models_async.py index 2c465d1d3..03055ffb3 100644 --- a/tests/test_generative_models_async.py +++ b/tests/test_generative_models_async.py @@ -24,14 +24,16 @@ from google.generativeai import client as client_lib from google.generativeai import generative_models from google.generativeai.types import content_types -import google.ai.generativelanguage as glm +from google.generativeai import protos from absl.testing import absltest from absl.testing import parameterized -def simple_response(text: str) -> glm.GenerateContentResponse: - return glm.GenerateContentResponse({"candidates": [{"content": {"parts": [{"text": text}]}}]}) +def simple_response(text: str) -> protos.GenerateContentResponse: + return protos.GenerateContentResponse( + {"candidates": [{"content": {"parts": [{"text": text}]}}]} + ) class AsyncTests(parameterized.TestCase, unittest.IsolatedAsyncioTestCase): @@ -50,28 +52,28 @@ def add_client_method(f): @add_client_method async def generate_content( - request: glm.GenerateContentRequest, + request: protos.GenerateContentRequest, **kwargs, - ) -> glm.GenerateContentResponse: - self.assertIsInstance(request, glm.GenerateContentRequest) + ) -> protos.GenerateContentResponse: + self.assertIsInstance(request, protos.GenerateContentRequest) self.observed_requests.append(request) response = self.responses["generate_content"].pop(0) return response @add_client_method async def stream_generate_content( - request: glm.GetModelRequest, + request: protos.GetModelRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) response = self.responses["stream_generate_content"].pop(0) return response @add_client_method async def count_tokens( - request: glm.CountTokensRequest, + request: protos.CountTokensRequest, **kwargs, - ) -> Iterable[glm.GenerateContentResponse]: + ) -> Iterable[protos.GenerateContentResponse]: self.observed_requests.append(request) response = self.responses["count_tokens"].pop(0) return response @@ -140,9 +142,9 @@ async def responses(): }, ), dict( - testcase_name="test_glm_FunctionCallingConfig", + 
testcase_name="test_protos.FunctionCallingConfig", tool_config={ - "function_calling_config": glm.FunctionCallingConfig( + "function_calling_config": protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.AUTO ) }, @@ -169,9 +171,9 @@ async def responses(): }, ), dict( - testcase_name="test_glm_ToolConfig", - tool_config=glm.ToolConfig( - function_calling_config=glm.FunctionCallingConfig( + testcase_name="test_protos.ToolConfig", + tool_config=protos.ToolConfig( + function_calling_config=protos.FunctionCallingConfig( mode=content_types.FunctionCallingMode.NONE ) ), @@ -211,7 +213,7 @@ async def test_tool_config(self, tool_config, expected_tool_config): ["contents", [{"role": "user", "parts": ["hello"]}]], ) async def test_count_tokens_smoke(self, contents): - self.responses["count_tokens"] = [glm.CountTokensResponse(total_tokens=7)] + self.responses["count_tokens"] = [protos.CountTokensResponse(total_tokens=7)] model = generative_models.GenerativeModel("gemini-pro-vision") response = await model.count_tokens_async(contents) self.assertEqual(type(response).to_dict(response), {"total_tokens": 7}) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 0c2de7f29..f060caf88 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -19,7 +19,7 @@ from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import client from google.generativeai import models @@ -35,15 +35,15 @@ def __init__(self, test): def get_model( self, - request: Union[glm.GetModelRequest, None] = None, + request: Union[protos.GetModelRequest, None] = None, *, name=None, timeout=None, retry=None - ) -> glm.Model: + ) -> protos.Model: if request is None: - request = glm.GetModelRequest(name=name) - self.test.assertIsInstance(request, glm.GetModelRequest) + request = protos.GetModelRequest(name=name) + self.test.assertIsInstance(request, protos.GetModelRequest) self.test.observed_requests.append(request) self.test.observed_timeout.append(timeout) self.test.observed_retry.append(retry) @@ -75,7 +75,7 @@ def setUp(self): ], ) def test_get_model(self, request_options, expected_timeout, expected_retry): - self.responses = {"get_model": glm.Model(name="models/fake-bison-001")} + self.responses = {"get_model": protos.Model(name="models/fake-bison-001")} _ = models.get_model("models/fake-bison-001", request_options=request_options) diff --git a/tests/test_models.py b/tests/test_models.py index f39ed3a2c..23f80913a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -25,7 +25,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.api_core import operation from google.generativeai import models @@ -45,7 +45,7 @@ def setUp(self): client._client_manager.clients["model"] = self.client # TODO(markdaoust): Check if typechecking works better if wee define this as a - # subclass of `glm.ModelServiceClient`, would pyi files for `glm` help? + # subclass of `glm.ModelServiceClient`, would pyi files for `glm`. help? 
def add_client_method(f): name = f.__name__ setattr(self.client, name, f) @@ -55,63 +55,65 @@ def add_client_method(f): self.responses = {} @add_client_method - def get_model(request: Union[glm.GetModelRequest, None] = None, *, name=None) -> glm.Model: + def get_model( + request: Union[protos.GetModelRequest, None] = None, *, name=None + ) -> protos.Model: if request is None: - request = glm.GetModelRequest(name=name) - self.assertIsInstance(request, glm.GetModelRequest) + request = protos.GetModelRequest(name=name) + self.assertIsInstance(request, protos.GetModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_model"]) return response @add_client_method def get_tuned_model( - request: Union[glm.GetTunedModelRequest, None] = None, + request: Union[protos.GetTunedModelRequest, None] = None, *, name=None, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: if request is None: - request = glm.GetTunedModelRequest(name=name) - self.assertIsInstance(request, glm.GetTunedModelRequest) + request = protos.GetTunedModelRequest(name=name) + self.assertIsInstance(request, protos.GetTunedModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @add_client_method def list_models( - request: Union[glm.ListModelsRequest, None] = None, + request: Union[protos.ListModelsRequest, None] = None, *, page_size=None, page_token=None, **kwargs, - ) -> glm.ListModelsResponse: + ) -> protos.ListModelsResponse: if request is None: - request = glm.ListModelsRequest(page_size=page_size, page_token=page_token) - self.assertIsInstance(request, glm.ListModelsRequest) + request = protos.ListModelsRequest(page_size=page_size, page_token=page_token) + self.assertIsInstance(request, protos.ListModelsRequest) self.observed_requests.append(request) response = self.responses["list_models"] return (item for item in response) @add_client_method def list_tuned_models( - request: glm.ListTunedModelsRequest = None, + request: protos.ListTunedModelsRequest = None, *, page_size=None, page_token=None, **kwargs, - ) -> Iterable[glm.TunedModel]: + ) -> Iterable[protos.TunedModel]: if request is None: - request = glm.ListTunedModelsRequest(page_size=page_size, page_token=page_token) - self.assertIsInstance(request, glm.ListTunedModelsRequest) + request = protos.ListTunedModelsRequest(page_size=page_size, page_token=page_token) + self.assertIsInstance(request, protos.ListTunedModelsRequest) self.observed_requests.append(request) response = self.responses["list_tuned_models"] return (item for item in response) @add_client_method def update_tuned_model( - request: glm.UpdateTunedModelRequest, + request: protos.UpdateTunedModelRequest, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: self.observed_requests.append(request) response = self.responses.get("update_tuned_model", None) if response is None: @@ -120,7 +122,7 @@ def update_tuned_model( @add_client_method def delete_tuned_model(name): - request = glm.DeleteTunedModelRequest(name=name) + request = protos.DeleteTunedModelRequest(name=name) self.observed_requests.append(request) response = True return response @@ -130,26 +132,26 @@ def create_tuned_model( request, **kwargs, ): - request = glm.CreateTunedModelRequest(request) + request = protos.CreateTunedModelRequest(request) self.observed_requests.append(request) return self.responses["create_tuned_model"] def test_decode_tuned_model_time_round_trip(self): example_dt = datetime.datetime(2000, 1, 2, 3, 4, 5, 600_000, 
pytz.UTC) - tuned_model = glm.TunedModel(name="tunedModels/house-mouse-001", create_time=example_dt) + tuned_model = protos.TunedModel(name="tunedModels/house-mouse-001", create_time=example_dt) tuned_model = model_types.decode_tuned_model(tuned_model) self.assertEqual(tuned_model.create_time, example_dt) @parameterized.named_parameters( ["simple", "models/fake-bison-001"], ["simple-tuned", "tunedModels/my-pig-001"], - ["model-instance", glm.Model(name="models/fake-bison-001")], - ["tuned-model-instance", glm.TunedModel(name="tunedModels/my-pig-001")], + ["model-instance", protos.Model(name="models/fake-bison-001")], + ["tuned-model-instance", protos.TunedModel(name="tunedModels/my-pig-001")], ) def test_get_model(self, name): self.responses = { - "get_model": glm.Model(name="models/fake-bison-001"), - "get_tuned_model": glm.TunedModel(name="tunedModels/my-pig-001"), + "get_model": protos.Model(name="models/fake-bison-001"), + "get_tuned_model": protos.TunedModel(name="tunedModels/my-pig-001"), } model = models.get_model(name) @@ -160,7 +162,7 @@ def test_get_model(self, name): @parameterized.named_parameters( ["simple", "mystery-bison-001"], - ["model-instance", glm.Model(name="how?-bison-001")], + ["model-instance", protos.Model(name="how?-bison-001")], ) def test_fail_with_unscoped_model_name(self, name): with self.assertRaises(ValueError): @@ -170,9 +172,9 @@ def test_list_models(self): # The low level lib wraps the response in an iterable, so this is a fair test. self.responses = { "list_models": [ - glm.Model(name="models/fake-bison-001"), - glm.Model(name="models/fake-bison-002"), - glm.Model(name="models/fake-bison-003"), + protos.Model(name="models/fake-bison-001"), + protos.Model(name="models/fake-bison-002"), + protos.Model(name="models/fake-bison-003"), ] } @@ -185,9 +187,9 @@ def test_list_tuned_models(self): self.responses = { # The low level lib wraps the response in an iterable, so this is a fair test. "list_tuned_models": [ - glm.TunedModel(name="tunedModels/my-pig-001"), - glm.TunedModel(name="tunedModels/my-pig-002"), - glm.TunedModel(name="tunedModels/my-pig-003"), + protos.TunedModel(name="tunedModels/my-pig-001"), + protos.TunedModel(name="tunedModels/my-pig-002"), + protos.TunedModel(name="tunedModels/my-pig-003"), ] } found_models = list(models.list_tuned_models()) @@ -197,8 +199,8 @@ def test_list_tuned_models(self): @parameterized.named_parameters( [ - "edited-glm-model", - glm.TunedModel( + "edited-protos.model", + protos.TunedModel( name="tunedModels/my-pig-001", description="Trained on my data", ), @@ -211,7 +213,7 @@ def test_list_tuned_models(self): ], ) def test_update_tuned_model_basics(self, tuned_model, updates): - self.responses["get_tuned_model"] = glm.TunedModel(name="tunedModels/my-pig-001") + self.responses["get_tuned_model"] = protos.TunedModel(name="tunedModels/my-pig-001") # No self.responses['update_tuned_model'] the mock just returns the input. 
updated_model = models.update_tuned_model(tuned_model, updates) updated_model.description = "Trained on my data" @@ -227,7 +229,7 @@ def test_update_tuned_model_basics(self, tuned_model, updates): ], ) def test_update_tuned_model_nested_fields(self, updates): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/my-pig-001", base_model="models/dance-monkey-007" ) @@ -250,8 +252,8 @@ def test_update_tuned_model_nested_fields(self, updates): @parameterized.named_parameters( ["name", "tunedModels/bipedal-pangolin-223"], [ - "glm.TunedModel", - glm.TunedModel(name="tunedModels/bipedal-pangolin-223"), + "protos.TunedModel", + protos.TunedModel(name="tunedModels/bipedal-pangolin-223"), ], [ "models.TunedModel", @@ -275,23 +277,23 @@ def test_decode_micros(self, time_str, micros): self.assertEqual(time["time"].microsecond, micros) def test_decode_tuned_model(self): - out_fields = glm.TunedModel( - state=glm.TunedModel.State.CREATING, + out_fields = protos.TunedModel( + state=protos.TunedModel.State.CREATING, create_time="2000-01-01T01:01:01.0Z", update_time="2001-01-01T01:01:01.0Z", - tuning_task=glm.TuningTask( - hyperparameters=glm.Hyperparameters( + tuning_task=protos.TuningTask( + hyperparameters=protos.Hyperparameters( batch_size=72, epoch_count=1, learning_rate=0.1 ), start_time="2002-01-01T01:01:01.0Z", complete_time="2003-01-01T01:01:01.0Z", snapshots=[ - glm.TuningSnapshot( + protos.TuningSnapshot( step=1, epoch=1, compute_time="2004-01-01T01:01:01.0Z", ), - glm.TuningSnapshot( + protos.TuningSnapshot( step=2, epoch=1, compute_time="2005-01-01T01:01:01.0Z", @@ -301,7 +303,7 @@ def test_decode_tuned_model(self): ) decoded = model_types.decode_tuned_model(out_fields) - self.assertEqual(decoded.state, glm.TunedModel.State.CREATING) + self.assertEqual(decoded.state, protos.TunedModel.State.CREATING) self.assertEqual(decoded.create_time.year, 2000) self.assertEqual(decoded.update_time.year, 2001) self.assertIsInstance(decoded.tuning_task.hyperparameters, model_types.Hyperparameters) @@ -314,10 +316,10 @@ def test_decode_tuned_model(self): self.assertEqual(decoded.tuning_task.snapshots[1]["compute_time"].year, 2005) @parameterized.named_parameters( - ["simple", glm.TunedModel(base_model="models/swim-fish-000")], + ["simple", protos.TunedModel(base_model="models/swim-fish-000")], [ "nested", - glm.TunedModel( + protos.TunedModel( tuned_model_source={ "tuned_model": "tunedModels/hidden-fish-55", "base_model": "models/swim-fish-000", @@ -341,7 +343,7 @@ def test_smoke_create_tuned_model(self): training_data=[ ("in", "out"), {"text_input": "in", "output": "out"}, - glm.TuningExample(text_input="in", output="out"), + protos.TuningExample(text_input="in", output="out"), ], ) req = self.observed_requests[-1] @@ -351,10 +353,10 @@ def test_smoke_create_tuned_model(self): self.assertLen(req.tuned_model.tuning_task.training_data.examples.examples, 3) @parameterized.named_parameters( - ["simple", glm.TunedModel(base_model="models/swim-fish-000")], + ["simple", protos.TunedModel(base_model="models/swim-fish-000")], [ "nested", - glm.TunedModel( + protos.TunedModel( tuned_model_source={ "tuned_model": "tunedModels/hidden-fish-55", "base_model": "models/swim-fish-000", @@ -380,9 +382,9 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source): @parameterized.named_parameters( [ - "glm", - glm.Dataset( - examples=glm.TuningExamples( + "protos", + protos.Dataset( + examples=protos.TuningExamples( examples=[ {"text_input": "a", 
"output": "1"}, {"text_input": "b", "output": "2"}, @@ -396,7 +398,7 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source): [ ("a", "1"), {"text_input": "b", "output": "2"}, - glm.TuningExample({"text_input": "c", "output": "3"}), + protos.TuningExample({"text_input": "c", "output": "3"}), ], ], ["dict", {"text_input": ["a", "b", "c"], "output": ["1", "2", "3"]}], @@ -445,8 +447,8 @@ def test_create_tuned_model_on_tuned_model(self, tuned_source): def test_create_dataset(self, data, ik="text_input", ok="output"): ds = model_types.encode_tuning_data(data, input_key=ik, output_key=ok) - expect = glm.Dataset( - examples=glm.TuningExamples( + expect = protos.Dataset( + examples=protos.TuningExamples( examples=[ {"text_input": "a", "output": "1"}, {"text_input": "b", "output": "2"}, @@ -502,7 +504,7 @@ def test_update_tuned_model_called_with_request_options(self): self.client.update_tuned_model = unittest.mock.MagicMock() request = unittest.mock.ANY request_options = {"timeout": 120} - self.responses["get_tuned_model"] = glm.TunedModel(name="tunedModels/") + self.responses["get_tuned_model"] = protos.TunedModel(name="tunedModels/") try: models.update_tuned_model( @@ -534,7 +536,7 @@ def test_create_tuned_model_called_with_request_options(self): training_data=[ ("in", "out"), {"text_input": "in", "output": "out"}, - glm.TuningExample(text_input="in", output="out"), + protos.TuningExample(text_input="in", output="out"), ], request_options=request_options, ) diff --git a/tests/test_operations.py b/tests/test_operations.py index 80262db88..6529b77e5 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -16,7 +16,7 @@ from contextlib import redirect_stderr import io -import google.ai.generativelanguage as glm +from google.generativeai import protos import google.protobuf.any_pb2 import google.generativeai.operations as genai_operation @@ -41,7 +41,7 @@ def test_end_to_end(self): # `Any` takes a type name and a serialized proto. metadata = google.protobuf.any_pb2.Any( type_url=self.metadata_type, - value=glm.CreateTunedModelMetadata(tuned_model=name)._pb.SerializeToString(), + value=protos.CreateTunedModelMetadata(tuned_model=name)._pb.SerializeToString(), ) # Initially the `Operation` is not `done`, so it only gives a metadata. @@ -58,7 +58,7 @@ def test_end_to_end(self): metadata=metadata, response=google.protobuf.any_pb2.Any( type_url=self.result_type, - value=glm.TunedModel(name=name)._pb.SerializeToString(), + value=protos.TunedModel(name=name)._pb.SerializeToString(), ), ) @@ -72,8 +72,8 @@ def refresh(*_, **__): operation=initial_pb, refresh=refresh, cancel=lambda: print(f"cancel!"), - result_type=glm.TunedModel, - metadata_type=glm.CreateTunedModelMetadata, + result_type=protos.TunedModel, + metadata_type=protos.CreateTunedModelMetadata, ) # Use our wrapper instead. 
@@ -99,7 +99,7 @@ def gen_operations(): def make_metadata(completed_steps): return google.protobuf.any_pb2.Any( type_url=self.metadata_type, - value=glm.CreateTunedModelMetadata( + value=protos.CreateTunedModelMetadata( tuned_model=name, total_steps=total_steps, completed_steps=completed_steps, @@ -122,7 +122,7 @@ def make_metadata(completed_steps): metadata=make_metadata(total_steps), response=google.protobuf.any_pb2.Any( type_url=self.result_type, - value=glm.TunedModel(name=name)._pb.SerializeToString(), + value=protos.TunedModel(name=name)._pb.SerializeToString(), ), ) @@ -142,8 +142,8 @@ def refresh(*_, **__): operation=initial_pb, refresh=refresh, cancel=None, - result_type=glm.TunedModel, - metadata_type=glm.CreateTunedModelMetadata, + result_type=protos.TunedModel, + metadata_type=protos.CreateTunedModelMetadata, ) # Use our wrapper instead. diff --git a/tests/test_permission.py b/tests/test_permission.py index 55ad7a2f0..66b396977 100644 --- a/tests/test_permission.py +++ b/tests/test_permission.py @@ -17,7 +17,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import permission @@ -50,11 +50,11 @@ def add_client_method(f): @add_client_method def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -63,24 +63,24 @@ def create_corpus( @add_client_method def get_tuned_model( - request: Optional[glm.GetTunedModelRequest] = None, + request: Optional[protos.GetTunedModelRequest] = None, *, name=None, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: if request is None: - request = glm.GetTunedModelRequest(name=name) - self.assertIsInstance(request, glm.GetTunedModelRequest) + request = protos.GetTunedModelRequest(name=name) + self.assertIsInstance(request, protos.GetTunedModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @add_client_method def create_permission( - request: glm.CreatePermissionRequest, - ) -> glm.Permission: + request: protos.CreatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -88,17 +88,17 @@ def create_permission( @add_client_method def delete_permission( - request: glm.DeletePermissionRequest, + request: protos.DeletePermissionRequest, ) -> None: self.observed_requests.append(request) return None @add_client_method def get_permission( - request: glm.GetPermissionRequest, - ) -> glm.Permission: + request: protos.GetPermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -106,16 +106,16 @@ def get_permission( @add_client_method def list_permissions( - request: glm.ListPermissionsRequest, - ) -> glm.ListPermissionsResponse: + request: protos.ListPermissionsRequest, + ) -> 
protos.ListPermissionsResponse: self.observed_requests.append(request) return [ - glm.Permission( + protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), ), - glm.Permission( + protos.Permission( name="corpora/demo-corpus/permissions/987654321", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -125,10 +125,10 @@ def list_permissions( @add_client_method def update_permission( - request: glm.UpdatePermissionRequest, - ) -> glm.Permission: + request: protos.UpdatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -136,16 +136,16 @@ def update_permission( @add_client_method def transfer_ownership( - request: glm.TransferOwnershipRequest, - ) -> glm.TransferOwnershipResponse: + request: protos.TransferOwnershipRequest, + ) -> protos.TransferOwnershipResponse: self.observed_requests.append(request) - return glm.TransferOwnershipResponse() + return protos.TransferOwnershipResponse() def test_create_permission_success(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create(role="writer", grantee_type="everyone", email_address=None) self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.CreatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.CreatePermissionRequest) def test_create_permission_failure_email_set_when_grantee_type_is_everyone(self): x = retriever.create_corpus("demo-corpus") @@ -161,14 +161,14 @@ def test_delete_permission(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create("writer", "everyone") perm.delete() - self.assertIsInstance(self.observed_requests[-1], glm.DeletePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeletePermissionRequest) def test_get_permission_with_full_name(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create("writer", "everyone") fetch_perm = permission.get_permission(name=perm.name) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) def test_get_permission_with_resource_name_and_id_1(self): @@ -178,7 +178,7 @@ def test_get_permission_with_resource_name_and_id_1(self): resource_name="corpora/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) def test_get_permission_with_resource_name_name_and_id_2(self): @@ -186,14 +186,14 @@ def test_get_permission_with_resource_name_name_and_id_2(self): resource_name="tunedModels/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) def 
test_get_permission_with_resource_type(self): fetch_perm = permission.get_permission( resource_name="demo-model", permission_id=123456789, resource_type="tunedModels" ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) @parameterized.named_parameters( dict( @@ -257,14 +257,14 @@ def test_list_permission(self): self.assertEqual(perms[1].email_address, "_") for perm in perms: self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.ListPermissionsRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListPermissionsRequest) def test_update_permission_success(self): x = retriever.create_corpus("demo-corpus") perm = x.permissions.create("writer", "everyone") updated_perm = perm.update({"role": permission_services.to_role("reader")}) self.assertIsInstance(updated_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.UpdatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.UpdatePermissionRequest) def test_update_permission_failure_restricted_update_path(self): x = retriever.create_corpus("demo-corpus") @@ -275,12 +275,12 @@ def test_update_permission_failure_restricted_update_path(self): ) def test_transfer_ownership(self): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/fake-pig-001", base_model="models/dance-monkey-007" ) x = models.get_tuned_model("tunedModels/fake-pig-001") response = x.permissions.transfer_ownership(email_address="_") - self.assertIsInstance(self.observed_requests[-1], glm.TransferOwnershipRequest) + self.assertIsInstance(self.observed_requests[-1], protos.TransferOwnershipRequest) def test_transfer_ownership_on_corpora(self): x = retriever.create_corpus("demo-corpus") diff --git a/tests/test_permission_async.py b/tests/test_permission_async.py index 165039122..ddc9c22a2 100644 --- a/tests/test_permission_async.py +++ b/tests/test_permission_async.py @@ -17,7 +17,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import permission @@ -49,11 +49,11 @@ def add_client_method(f): @add_client_method async def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -62,24 +62,24 @@ async def create_corpus( @add_client_method def get_tuned_model( - request: Optional[glm.GetTunedModelRequest] = None, + request: Optional[protos.GetTunedModelRequest] = None, *, name=None, **kwargs, - ) -> glm.TunedModel: + ) -> protos.TunedModel: if request is None: - request = glm.GetTunedModelRequest(name=name) - self.assertIsInstance(request, glm.GetTunedModelRequest) + request = protos.GetTunedModelRequest(name=name) + self.assertIsInstance(request, protos.GetTunedModelRequest) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @add_client_method async def create_permission( - request: glm.CreatePermissionRequest, - ) -> glm.Permission: 
+ request: protos.CreatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -87,17 +87,17 @@ async def create_permission( @add_client_method async def delete_permission( - request: glm.DeletePermissionRequest, + request: protos.DeletePermissionRequest, ) -> None: self.observed_requests.append(request) return None @add_client_method async def get_permission( - request: glm.GetPermissionRequest, - ) -> glm.Permission: + request: protos.GetPermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -105,17 +105,17 @@ async def get_permission( @add_client_method async def list_permissions( - request: glm.ListPermissionsRequest, - ) -> glm.ListPermissionsResponse: + request: protos.ListPermissionsRequest, + ) -> protos.ListPermissionsResponse: self.observed_requests.append(request) async def results(): - yield glm.Permission( + yield protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("writer"), grantee_type=permission_services.to_grantee_type("everyone"), ) - yield glm.Permission( + yield protos.Permission( name="corpora/demo-corpus/permissions/987654321", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -126,10 +126,10 @@ async def results(): @add_client_method async def update_permission( - request: glm.UpdatePermissionRequest, - ) -> glm.Permission: + request: protos.UpdatePermissionRequest, + ) -> protos.Permission: self.observed_requests.append(request) - return glm.Permission( + return protos.Permission( name="corpora/demo-corpus/permissions/123456789", role=permission_services.to_role("reader"), grantee_type=permission_services.to_grantee_type("everyone"), @@ -137,10 +137,10 @@ async def update_permission( @add_client_method async def transfer_ownership( - request: glm.TransferOwnershipRequest, - ) -> glm.TransferOwnershipResponse: + request: protos.TransferOwnershipRequest, + ) -> protos.TransferOwnershipResponse: self.observed_requests.append(request) - return glm.TransferOwnershipResponse() + return protos.TransferOwnershipResponse() async def test_create_permission_success(self): x = await retriever.create_corpus_async("demo-corpus") @@ -148,7 +148,7 @@ async def test_create_permission_success(self): role="writer", grantee_type="everyone", email_address=None ) self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.CreatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.CreatePermissionRequest) async def test_create_permission_failure_email_set_when_grantee_type_is_everyone(self): x = await retriever.create_corpus_async("demo-corpus") @@ -168,14 +168,14 @@ async def test_delete_permission(self): x = await retriever.create_corpus_async("demo-corpus") perm = await x.permissions.create_async("writer", "everyone") await perm.delete_async() - self.assertIsInstance(self.observed_requests[-1], glm.DeletePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeletePermissionRequest) async def 
test_get_permission_with_full_name(self): x = await retriever.create_corpus_async("demo-corpus") perm = await x.permissions.create_async("writer", "everyone") fetch_perm = await permission.get_permission_async(name=perm.name) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) async def test_get_permission_with_resource_name_and_id_1(self): @@ -185,7 +185,7 @@ async def test_get_permission_with_resource_name_and_id_1(self): resource_name="corpora/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) self.assertEqual(fetch_perm, perm) async def test_get_permission_with_resource_name_name_and_id_2(self): @@ -193,14 +193,14 @@ async def test_get_permission_with_resource_name_name_and_id_2(self): resource_name="tunedModels/demo-corpus", permission_id=123456789 ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) async def test_get_permission_with_resource_type(self): fetch_perm = await permission.get_permission_async( resource_name="demo-model", permission_id=123456789, resource_type="tunedModels" ) self.assertIsInstance(fetch_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.GetPermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.GetPermissionRequest) @parameterized.named_parameters( dict( @@ -264,14 +264,14 @@ async def test_list_permission(self): self.assertEqual(perms[1].email_address, "_") for perm in perms: self.assertIsInstance(perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.ListPermissionsRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListPermissionsRequest) async def test_update_permission_success(self): x = await retriever.create_corpus_async("demo-corpus") perm = await x.permissions.create_async("writer", "everyone") updated_perm = await perm.update_async({"role": permission_services.to_role("reader")}) self.assertIsInstance(updated_perm, permission_services.Permission) - self.assertIsInstance(self.observed_requests[-1], glm.UpdatePermissionRequest) + self.assertIsInstance(self.observed_requests[-1], protos.UpdatePermissionRequest) async def test_update_permission_failure_restricted_update_path(self): x = await retriever.create_corpus_async("demo-corpus") @@ -282,12 +282,12 @@ async def test_update_permission_failure_restricted_update_path(self): ) async def test_transfer_ownership(self): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/fake-pig-001", base_model="models/dance-monkey-007" ) x = models.get_tuned_model("tunedModels/fake-pig-001") response = await x.permissions.transfer_ownership_async(email_address="_") - self.assertIsInstance(self.observed_requests[-1], glm.TransferOwnershipRequest) + self.assertIsInstance(self.observed_requests[-1], protos.TransferOwnershipRequest) async def test_transfer_ownership_on_corpora(self): x = await retriever.create_corpus_async("demo-corpus") 
diff --git a/tests/test_protos.py b/tests/test_protos.py new file mode 100644 index 000000000..1b59b0c6e --- /dev/null +++ b/tests/test_protos.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pathlib +import re + +from absl.testing import parameterized + +ROOT = pathlib.Path(__file__).parent.parent + + +class UnitTests(parameterized.TestCase): + def test_check_glm_imports(self): + for fpath in ROOT.rglob("*.py"): + if fpath.name == "build_docs.py": + continue + content = fpath.read_text() + for match in re.findall("glm\.\w+", content): + self.assertIn( + "Client", + match, + msg=f"Bad `glm.` usage, use `genai.protos` instead,\n in {fpath}", + ) diff --git a/tests/test_responder.py b/tests/test_responder.py index 4eb310815..c075fc65a 100644 --- a/tests/test_responder.py +++ b/tests/test_responder.py @@ -17,7 +17,7 @@ from absl.testing import absltest from absl.testing import parameterized -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import responder import IPython.display import PIL.Image @@ -42,9 +42,9 @@ class UnitTests(parameterized.TestCase): [ "FunctionLibrary", responder.FunctionLibrary( - tools=glm.Tool( + tools=protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -56,7 +56,7 @@ class UnitTests(parameterized.TestCase): [ responder.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] @@ -64,11 +64,11 @@ class UnitTests(parameterized.TestCase): ], ], [ - "IterableTool-glm.Tool", + "IterableTool-protos.Tool", [ - glm.Tool( + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -93,7 +93,7 @@ class UnitTests(parameterized.TestCase): "IterableTool-IterableFD", [ [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -103,7 +103,7 @@ class UnitTests(parameterized.TestCase): [ "IterableTool-FD", [ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time.", ) @@ -113,17 +113,17 @@ class UnitTests(parameterized.TestCase): "Tool", responder.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ) ] ), ], [ - "glm.Tool", - glm.Tool( + "protos.Tool", + protos.Tool( function_declarations=[ - glm.FunctionDeclaration( + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." 
) ] @@ -175,8 +175,8 @@ class UnitTests(parameterized.TestCase): ), ], [ - "glm.FD", - glm.FunctionDeclaration( + "protos.FD", + protos.FunctionDeclaration( name="datetime", description="Returns the current UTC date and time." ), ], @@ -216,32 +216,32 @@ def b(): self.assertLen(tools[0].function_declarations, 2) @parameterized.named_parameters( - ["int", int, glm.Schema(type=glm.Type.INTEGER)], - ["float", float, glm.Schema(type=glm.Type.NUMBER)], - ["str", str, glm.Schema(type=glm.Type.STRING)], + ["int", int, protos.Schema(type=protos.Type.INTEGER)], + ["float", float, protos.Schema(type=protos.Type.NUMBER)], + ["str", str, protos.Schema(type=protos.Type.STRING)], [ "list", list[str], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.STRING), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.STRING), ), ], [ "list-list-int", list[list[int]], - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema( - glm.Schema( - type=glm.Type.ARRAY, - items=glm.Schema(type=glm.Type.INTEGER), + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema( + protos.Schema( + type=protos.Type.ARRAY, + items=protos.Schema(type=protos.Type.INTEGER), ), ), ), ], - ["dict", dict, glm.Schema(type=glm.Type.OBJECT)], - ["dict-str-any", dict[str, Any], glm.Schema(type=glm.Type.OBJECT)], + ["dict", dict, protos.Schema(type=protos.Type.OBJECT)], + ["dict-str-any", dict[str, Any], protos.Schema(type=protos.Type.OBJECT)], ) def test_auto_schema(self, annotation, expected): def fun(a: annotation): diff --git a/tests/test_retriever.py b/tests/test_retriever.py index 910183789..bce9a402b 100644 --- a/tests/test_retriever.py +++ b/tests/test_retriever.py @@ -16,7 +16,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import client @@ -42,11 +42,11 @@ def add_client_method(f): @add_client_method def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo_corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -55,11 +55,11 @@ def create_corpus( @add_client_method def get_corpus( - request: glm.GetCorpusRequest, + request: protos.GetCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo_corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -68,11 +68,11 @@ def get_corpus( @add_client_method def update_corpus( - request: glm.UpdateCorpusRequest, + request: protos.UpdateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", @@ -81,18 +81,18 @@ def update_corpus( @add_client_method def list_corpora( - request: glm.ListCorporaRequest, + request: protos.ListCorporaRequest, **kwargs, - ) -> glm.ListCorporaResponse: + ) -> protos.ListCorporaResponse: self.observed_requests.append(request) return [ - glm.Corpus( + protos.Corpus( name="corpora/demo_corpus-1", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Corpus( 
+ protos.Corpus( name="corpora/demo-corpus-2", display_name="demo-corpus-2", create_time="2000-01-01T01:01:01.123456Z", @@ -102,15 +102,15 @@ def list_corpora( @add_client_method def query_corpus( - request: glm.QueryCorpusRequest, + request: protos.QueryCorpusRequest, **kwargs, - ) -> glm.QueryCorpusResponse: + ) -> protos.QueryCorpusResponse: self.observed_requests.append(request) - return glm.QueryCorpusResponse( + return protos.QueryCorpusResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -124,18 +124,18 @@ def query_corpus( @add_client_method def delete_corpus( - request: glm.DeleteCorpusRequest, + request: protos.DeleteCorpusRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method def create_document( - request: glm.CreateDocumentRequest, + request: protos.CreateDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -144,11 +144,11 @@ def create_document( @add_client_method def get_document( - request: glm.GetDocumentRequest, + request: protos.GetDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo_doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -157,11 +157,11 @@ def get_document( @add_client_method def update_document( - request: glm.UpdateDocumentRequest, + request: protos.UpdateDocumentRequest, **kwargs, - ) -> glm.Document: + ) -> protos.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo_doc", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", @@ -170,18 +170,18 @@ def update_document( @add_client_method def list_documents( - request: glm.ListDocumentsRequest, + request: protos.ListDocumentsRequest, **kwargs, - ) -> glm.ListDocumentsResponse: + ) -> protos.ListDocumentsResponse: self.observed_requests.append(request) return [ - glm.Document( + protos.Document( name="corpora/demo-corpus/documents/demo_doc_1", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Document( + protos.Document( name="corpora/demo-corpus/documents/demo_doc_2", display_name="demo-doc-2", create_time="2000-01-01T01:01:01.123456Z", @@ -191,22 +191,22 @@ def list_documents( @add_client_method def delete_document( - request: glm.DeleteDocumentRequest, + request: protos.DeleteDocumentRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method def query_document( - request: glm.QueryDocumentRequest, + request: protos.QueryDocumentRequest, **kwargs, - ) -> glm.QueryDocumentResponse: + ) -> protos.QueryDocumentResponse: self.observed_requests.append(request) - return glm.QueryDocumentResponse( + return protos.QueryDocumentResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -220,11 +220,11 @@ def query_document( 
@add_client_method def create_chunk( - request: glm.CreateChunkRequest, + request: protos.CreateChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -233,19 +233,19 @@ def create_chunk( @add_client_method def batch_create_chunks( - request: glm.BatchCreateChunksRequest, + request: protos.BatchCreateChunksRequest, **kwargs, - ) -> glm.BatchCreateChunksResponse: + ) -> protos.BatchCreateChunksResponse: self.observed_requests.append(request) - return glm.BatchCreateChunksResponse( + return protos.BatchCreateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/dc", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/dc1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -256,11 +256,11 @@ def batch_create_chunks( @add_client_method def get_chunk( - request: glm.GetChunkRequest, + request: protos.GetChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -269,18 +269,18 @@ def get_chunk( @add_client_method def list_chunks( - request: glm.ListChunksRequest, + request: protos.ListChunksRequest, **kwargs, - ) -> glm.ListChunksResponse: + ) -> protos.ListChunksResponse: self.observed_requests.append(request) return [ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/demo-chunk-1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -290,17 +290,17 @@ def list_chunks( @add_client_method def update_chunk( - request: glm.UpdateChunkRequest, + request: protos.UpdateChunkRequest, **kwargs, - ) -> glm.Chunk: + ) -> protos.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated demo chunk."}, custom_metadata=[ - glm.CustomMetadata( + protos.CustomMetadata( key="tags", - string_list_value=glm.StringList( + string_list_value=protos.StringList( values=["Google For Developers", "Project IDX", "Blog", "Announcement"] ), ) @@ -311,19 +311,19 @@ def update_chunk( @add_client_method def batch_update_chunks( - request: glm.BatchUpdateChunksRequest, + request: protos.BatchUpdateChunksRequest, **kwargs, - ) -> glm.BatchUpdateChunksResponse: + ) -> protos.BatchUpdateChunksResponse: self.observed_requests.append(request) - return glm.BatchUpdateChunksResponse( + return protos.BatchUpdateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - 
glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/demo-doc/chunks/demo-chunk-1", data={"string_value": "This is another updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -334,14 +334,14 @@ def batch_update_chunks( @add_client_method def delete_chunk( - request: glm.DeleteChunkRequest, + request: protos.DeleteChunkRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method def batch_delete_chunks( - request: glm.BatchDeleteChunksRequest, + request: protos.BatchDeleteChunksRequest, **kwargs, ) -> None: self.observed_requests.append(request) @@ -366,7 +366,7 @@ def test_get_corpus(self, name="demo-corpus"): def test_update_corpus(self): demo_corpus = retriever.create_corpus(name="demo-corpus") update_request = demo_corpus.update(updates={"display_name": "demo-corpus_1"}) - self.assertIsInstance(self.observed_requests[-1], glm.UpdateCorpusRequest) + self.assertIsInstance(self.observed_requests[-1], protos.UpdateCorpusRequest) self.assertEqual("demo-corpus_1", demo_corpus.display_name) def test_list_corpora(self): @@ -402,7 +402,7 @@ def test_delete_corpus(self): demo_corpus = retriever.create_corpus(name="demo-corpus") demo_document = demo_corpus.create_document(name="demo-doc") delete_request = retriever.delete_corpus(name="corpora/demo_corpus", force=True) - self.assertIsInstance(self.observed_requests[-1], glm.DeleteCorpusRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCorpusRequest) def test_create_document(self, display_name="demo-doc"): demo_corpus = retriever.create_corpus(name="demo-corpus") @@ -433,7 +433,7 @@ def test_delete_document(self): demo_document = demo_corpus.create_document(name="demo-doc") demo_doc2 = demo_corpus.create_document(name="demo-doc-2") delete_request = demo_corpus.delete_document(name="corpora/demo-corpus/documents/demo_doc") - self.assertIsInstance(self.observed_requests[-1], glm.DeleteDocumentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteDocumentRequest) def test_list_documents(self): demo_corpus = retriever.create_corpus(name="demo-corpus") @@ -521,7 +521,7 @@ def test_batch_create_chunks(self, chunks): demo_corpus = retriever.create_corpus(name="demo-corpus") demo_document = demo_corpus.create_document(name="demo-doc") chunks = demo_document.batch_create_chunks(chunks=chunks) - self.assertIsInstance(self.observed_requests[-1], glm.BatchCreateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchCreateChunksRequest) self.assertEqual("This is a demo chunk.", chunks[0].data.string_value) self.assertEqual("This is another demo chunk.", chunks[1].data.string_value) @@ -548,7 +548,7 @@ def test_list_chunks(self): ) list_req = list(demo_document.list_chunks()) - self.assertIsInstance(self.observed_requests[-1], glm.ListChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListChunksRequest) self.assertLen(list_req, 2) def test_update_chunk(self): @@ -615,7 +615,7 @@ def test_batch_update_chunks_data_structures(self, updates): data="This is another demo chunk.", ) update_request = demo_document.batch_update_chunks(chunks=updates) - self.assertIsInstance(self.observed_requests[-1], glm.BatchUpdateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchUpdateChunksRequest) self.assertEqual( "This is an updated chunk.", update_request["chunks"][0]["data"]["string_value"] ) @@ -631,7 +631,7 @@ def test_delete_chunk(self): data="This is a demo chunk.", ) delete_request = 
demo_document.delete_chunk(name="demo-chunk") - self.assertIsInstance(self.observed_requests[-1], glm.DeleteChunkRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteChunkRequest) def test_batch_delete_chunks(self): demo_corpus = retriever.create_corpus(name="demo-corpus") @@ -645,7 +645,7 @@ def test_batch_delete_chunks(self): data="This is another demo chunk.", ) delete_request = demo_document.batch_delete_chunks(chunks=[x.name, y.name]) - self.assertIsInstance(self.observed_requests[-1], glm.BatchDeleteChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchDeleteChunksRequest) @parameterized.parameters( {"method": "create_corpus"}, diff --git a/tests/test_retriever_async.py b/tests/test_retriever_async.py index b764c23b2..bb0c862d1 100644 --- a/tests/test_retriever_async.py +++ b/tests/test_retriever_async.py @@ -19,7 +19,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import retriever from google.generativeai import client as client_lib @@ -44,11 +44,11 @@ def add_client_method(f): @add_client_method async def create_corpus( - request: glm.CreateCorpusRequest, + request: protos.CreateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -57,11 +57,11 @@ async def create_corpus( @add_client_method async def get_corpus( - request: glm.GetCorpusRequest, + request: protos.GetCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus", create_time="2000-01-01T01:01:01.123456Z", @@ -70,11 +70,11 @@ async def get_corpus( @add_client_method async def update_corpus( - request: glm.UpdateCorpusRequest, + request: protos.UpdateCorpusRequest, **kwargs, - ) -> glm.Corpus: + ) -> protos.Corpus: self.observed_requests.append(request) - return glm.Corpus( + return protos.Corpus( name="corpora/demo-corpus", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", @@ -83,19 +83,19 @@ async def update_corpus( @add_client_method async def list_corpora( - request: glm.ListCorporaRequest, + request: protos.ListCorporaRequest, **kwargs, - ) -> glm.ListCorporaResponse: + ) -> protos.ListCorporaResponse: self.observed_requests.append(request) async def results(): - yield glm.Corpus( + yield protos.Corpus( name="corpora/demo-corpus-1", display_name="demo-corpus-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ) - yield glm.Corpus( + yield protos.Corpus( name="corpora/demo-corpus_2", display_name="demo-corpus-2", create_time="2000-01-01T01:01:01.123456Z", @@ -106,15 +106,15 @@ async def results(): @add_client_method async def query_corpus( - request: glm.QueryCorpusRequest, + request: protos.QueryCorpusRequest, **kwargs, - ) -> glm.QueryCorpusResponse: + ) -> protos.QueryCorpusResponse: self.observed_requests.append(request) - return glm.QueryCorpusResponse( + return protos.QueryCorpusResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -128,18 
+128,18 @@ async def query_corpus( @add_client_method async def delete_corpus( - request: glm.DeleteCorpusRequest, + request: protos.DeleteCorpusRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method async def create_document( - request: glm.CreateDocumentRequest, + request: protos.CreateDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -148,11 +148,11 @@ async def create_document( @add_client_method async def get_document( - request: glm.GetDocumentRequest, + request: protos.GetDocumentRequest, **kwargs, ) -> retriever_service.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc", create_time="2000-01-01T01:01:01.123456Z", @@ -161,11 +161,11 @@ async def get_document( @add_client_method async def update_document( - request: glm.UpdateDocumentRequest, + request: protos.UpdateDocumentRequest, **kwargs, - ) -> glm.Document: + ) -> protos.Document: self.observed_requests.append(request) - return glm.Document( + return protos.Document( name="corpora/demo-corpus/documents/demo-doc", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", @@ -174,19 +174,19 @@ async def update_document( @add_client_method async def list_documents( - request: glm.ListDocumentsRequest, + request: protos.ListDocumentsRequest, **kwargs, - ) -> glm.ListDocumentsResponse: + ) -> protos.ListDocumentsResponse: self.observed_requests.append(request) async def results(): - yield glm.Document( + yield protos.Document( name="corpora/demo-corpus/documents/dem-doc_1", display_name="demo-doc-1", create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ) - yield glm.Document( + yield protos.Document( name="corpora/demo-corpus/documents/dem-doc_2", display_name="demo-doc_2", create_time="2000-01-01T01:01:01.123456Z", @@ -197,22 +197,22 @@ async def results(): @add_client_method async def delete_document( - request: glm.DeleteDocumentRequest, + request: protos.DeleteDocumentRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method async def query_document( - request: glm.QueryDocumentRequest, + request: protos.QueryDocumentRequest, **kwargs, - ) -> glm.QueryDocumentResponse: + ) -> protos.QueryDocumentResponse: self.observed_requests.append(request) - return glm.QueryDocumentResponse( + return protos.QueryDocumentResponse( relevant_chunks=[ - glm.RelevantChunk( + protos.RelevantChunk( chunk_relevance_score=0.08, - chunk=glm.Chunk( + chunk=protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, custom_metadata=[], @@ -226,11 +226,11 @@ async def query_document( @add_client_method async def create_chunk( - request: glm.CreateChunkRequest, + request: protos.CreateChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -239,19 +239,19 @@ async def create_chunk( @add_client_method async def batch_create_chunks( - request: glm.BatchCreateChunksRequest, + request: protos.BatchCreateChunksRequest, **kwargs, 
- ) -> glm.BatchCreateChunksResponse: + ) -> protos.BatchCreateChunksResponse: self.observed_requests.append(request) - return glm.BatchCreateChunksResponse( + return protos.BatchCreateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/dc", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/dc1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -262,11 +262,11 @@ async def batch_create_chunks( @add_client_method async def get_chunk( - request: glm.GetChunkRequest, + request: protos.GetChunkRequest, **kwargs, ) -> retriever_service.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -275,19 +275,19 @@ async def get_chunk( @add_client_method async def list_chunks( - request: glm.ListChunksRequest, + request: protos.ListChunksRequest, **kwargs, - ) -> glm.ListChunksResponse: + ) -> protos.ListChunksResponse: self.observed_requests.append(request) async def results(): - yield glm.Chunk( + yield protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is a demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ) - yield glm.Chunk( + yield protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk-1", data={"string_value": "This is another demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -298,11 +298,11 @@ async def results(): @add_client_method async def update_chunk( - request: glm.UpdateChunkRequest, + request: protos.UpdateChunkRequest, **kwargs, - ) -> glm.Chunk: + ) -> protos.Chunk: self.observed_requests.append(request) - return glm.Chunk( + return protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated demo chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -311,19 +311,19 @@ async def update_chunk( @add_client_method async def batch_update_chunks( - request: glm.BatchUpdateChunksRequest, + request: protos.BatchUpdateChunksRequest, **kwargs, - ) -> glm.BatchUpdateChunksResponse: + ) -> protos.BatchUpdateChunksResponse: self.observed_requests.append(request) - return glm.BatchUpdateChunksResponse( + return protos.BatchUpdateChunksResponse( chunks=[ - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk", data={"string_value": "This is an updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", update_time="2000-01-01T01:01:01.123456Z", ), - glm.Chunk( + protos.Chunk( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk-1", data={"string_value": "This is another updated chunk."}, create_time="2000-01-01T01:01:01.123456Z", @@ -334,14 +334,14 @@ async def batch_update_chunks( @add_client_method async def delete_chunk( - request: glm.DeleteChunkRequest, + request: protos.DeleteChunkRequest, **kwargs, ) -> None: self.observed_requests.append(request) @add_client_method async def batch_delete_chunks( - request: glm.BatchDeleteChunksRequest, + request: protos.BatchDeleteChunksRequest, **kwargs, ) -> None: self.observed_requests.append(request) @@ -398,7 +398,7 @@ async def test_delete_corpus(self): 
demo_corpus = await retriever.create_corpus_async(name="demo-corpus") demo_document = await demo_corpus.create_document_async(name="demo-doc") delete_request = await retriever.delete_corpus_async(name="corpora/demo-corpus", force=True) - self.assertIsInstance(self.observed_requests[-1], glm.DeleteCorpusRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteCorpusRequest) async def test_create_document(self, display_name="demo-doc"): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") @@ -425,7 +425,7 @@ async def test_delete_document(self): delete_request = await demo_corpus.delete_document_async( name="corpora/demo-corpus/documents/demo-doc" ) - self.assertIsInstance(self.observed_requests[-1], glm.DeleteDocumentRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteDocumentRequest) async def test_list_documents(self): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") @@ -513,7 +513,7 @@ async def test_batch_create_chunks(self, chunks): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") demo_document = await demo_corpus.create_document_async(name="demo-doc") chunks = await demo_document.batch_create_chunks_async(chunks=chunks) - self.assertIsInstance(self.observed_requests[-1], glm.BatchCreateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchCreateChunksRequest) self.assertEqual("This is a demo chunk.", chunks[0].data.string_value) self.assertEqual("This is another demo chunk.", chunks[1].data.string_value) @@ -541,7 +541,7 @@ async def test_list_chunks(self): chunks = [] async for chunk in demo_document.list_chunks_async(): chunks.append(chunk) - self.assertIsInstance(self.observed_requests[-1], glm.ListChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.ListChunksRequest) self.assertLen(chunks, 2) async def test_update_chunk(self): @@ -597,7 +597,7 @@ async def test_batch_update_chunks_data_structures(self, updates): data="This is another demo chunk.", ) update_request = await demo_document.batch_update_chunks_async(chunks=updates) - self.assertIsInstance(self.observed_requests[-1], glm.BatchUpdateChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchUpdateChunksRequest) self.assertEqual( "This is an updated chunk.", update_request["chunks"][0]["data"]["string_value"] ) @@ -615,7 +615,7 @@ async def test_delete_chunk(self): delete_request = await demo_document.delete_chunk_async( name="corpora/demo-corpus/documents/dem-doc/chunks/demo-chunk" ) - self.assertIsInstance(self.observed_requests[-1], glm.DeleteChunkRequest) + self.assertIsInstance(self.observed_requests[-1], protos.DeleteChunkRequest) async def test_batch_delete_chunks(self): demo_corpus = await retriever.create_corpus_async(name="demo-corpus") @@ -629,7 +629,7 @@ async def test_batch_delete_chunks(self): data="This is another demo chunk.", ) delete_request = await demo_document.batch_delete_chunks_async(chunks=[x.name, y.name]) - self.assertIsInstance(self.observed_requests[-1], glm.BatchDeleteChunksRequest) + self.assertIsInstance(self.observed_requests[-1], protos.BatchDeleteChunksRequest) async def test_get_corpus_called_with_request_options(self): self.client.get_corpus = unittest.mock.AsyncMock() diff --git a/tests/test_safety.py b/tests/test_safety.py index f3efc4aca..2ac8aca46 100644 --- a/tests/test_safety.py +++ b/tests/test_safety.py @@ -15,26 +15,26 @@ from absl.testing import absltest from absl.testing import parameterized -import 
google.ai.generativelanguage as glm from google.generativeai.types import safety_types +from google.generativeai import protos class SafetyTests(parameterized.TestCase): """Tests are in order with the design doc.""" @parameterized.named_parameters( - ["block_threshold", glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE], + ["block_threshold", protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE], ["block_threshold2", "medium"], ["block_threshold3", 2], ["dict", {"danger": "medium"}], ["dict2", {"danger": 2}], - ["dict3", {"danger": glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}], + ["dict3", {"danger": protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE}], [ "list-dict", [ dict( - category=glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, - threshold=glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + category=protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, ), ], ], @@ -48,8 +48,8 @@ class SafetyTests(parameterized.TestCase): def test_safety_overwrite(self, setting): setting = safety_types.to_easy_safety_dict(setting) self.assertEqual( - setting[glm.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT], - glm.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, + setting[protos.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT], + protos.SafetySetting.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, ) diff --git a/tests/test_text.py b/tests/test_text.py index 5dcda93b9..795c3dfcd 100644 --- a/tests/test_text.py +++ b/tests/test_text.py @@ -18,7 +18,7 @@ import unittest import unittest.mock as mock -import google.ai.generativelanguage as glm +from google.generativeai import protos from google.generativeai import text as text_service from google.generativeai import client @@ -46,42 +46,42 @@ def add_client_method(f): @add_client_method def generate_text( - request: glm.GenerateTextRequest, + request: protos.GenerateTextRequest, **kwargs, - ) -> glm.GenerateTextResponse: + ) -> protos.GenerateTextResponse: self.observed_requests.append(request) return self.responses["generate_text"] @add_client_method def embed_text( - request: glm.EmbedTextRequest, + request: protos.EmbedTextRequest, **kwargs, - ) -> glm.EmbedTextResponse: + ) -> protos.EmbedTextResponse: self.observed_requests.append(request) return self.responses["embed_text"] @add_client_method def batch_embed_text( - request: glm.EmbedTextRequest, + request: protos.EmbedTextRequest, **kwargs, - ) -> glm.EmbedTextResponse: + ) -> protos.EmbedTextResponse: self.observed_requests.append(request) - return glm.BatchEmbedTextResponse( - embeddings=[glm.Embedding(value=[1, 2, 3])] * len(request.texts) + return protos.BatchEmbedTextResponse( + embeddings=[protos.Embedding(value=[1, 2, 3])] * len(request.texts) ) @add_client_method def count_text_tokens( - request: glm.CountTextTokensRequest, + request: protos.CountTextTokensRequest, **kwargs, - ) -> glm.CountTextTokensResponse: + ) -> protos.CountTextTokensResponse: self.observed_requests.append(request) return self.responses["count_text_tokens"] @add_client_method - def get_tuned_model(name) -> glm.TunedModel: - request = glm.GetTunedModelRequest(name=name) + def get_tuned_model(name) -> protos.TunedModel: + request = protos.GetTunedModelRequest(name=name) self.observed_requests.append(request) response = copy.copy(self.responses["get_tuned_model"]) return response @@ -93,7 +93,7 @@ def get_tuned_model(name) -> glm.TunedModel: ) def test_make_prompt(self, prompt): x = 
text_service._make_text_prompt(prompt) - self.assertIsInstance(x, glm.TextPrompt) + self.assertIsInstance(x, protos.TextPrompt) self.assertEqual("Hello how are", x.text) @parameterized.named_parameters( @@ -104,7 +104,7 @@ def test_make_prompt(self, prompt): def test_make_generate_text_request(self, prompt): x = text_service._make_generate_text_request(model="models/chat-bison-001", prompt=prompt) self.assertEqual("models/chat-bison-001", x.model) - self.assertIsInstance(x, glm.GenerateTextRequest) + self.assertIsInstance(x, protos.GenerateTextRequest) @parameterized.named_parameters( [ @@ -116,14 +116,16 @@ def test_make_generate_text_request(self, prompt): ] ) def test_generate_embeddings(self, model, text): - self.responses["embed_text"] = glm.EmbedTextResponse( - embedding=glm.Embedding(value=[1, 2, 3]) + self.responses["embed_text"] = protos.EmbedTextResponse( + embedding=protos.Embedding(value=[1, 2, 3]) ) emb = text_service.generate_embeddings(model=model, text=text) self.assertIsInstance(emb, dict) - self.assertEqual(self.observed_requests[-1], glm.EmbedTextRequest(model=model, text=text)) + self.assertEqual( + self.observed_requests[-1], protos.EmbedTextRequest(model=model, text=text) + ) self.assertIsInstance(emb["embedding"][0], float) @parameterized.named_parameters( @@ -191,11 +193,11 @@ def test_generate_embeddings_batch(self, model, text): ] ) def test_generate_response(self, *, prompt, **kwargs): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ - glm.TextCompletion(output=" road?"), - glm.TextCompletion(output=" bridge?"), - glm.TextCompletion(output=" river?"), + protos.TextCompletion(output=" road?"), + protos.TextCompletion(output=" bridge?"), + protos.TextCompletion(output=" river?"), ] ) @@ -203,8 +205,8 @@ def test_generate_response(self, *, prompt, **kwargs): self.assertEqual( self.observed_requests[-1], - glm.GenerateTextRequest( - model="models/text-bison-001", prompt=glm.TextPrompt(text=prompt), **kwargs + protos.GenerateTextRequest( + model="models/text-bison-001", prompt=protos.TextPrompt(text=prompt), **kwargs ), ) @@ -220,20 +222,20 @@ def test_generate_response(self, *, prompt, **kwargs): ) def test_stop_string(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ - glm.TextCompletion(output="Hello world?"), - glm.TextCompletion(output="Hell!"), - glm.TextCompletion(output="I'm going to stop"), + protos.TextCompletion(output="Hello world?"), + protos.TextCompletion(output="Hell!"), + protos.TextCompletion(output="I'm going to stop"), ] ) complete = text_service.generate_text(prompt="Hello", stop_sequences="stop") self.assertEqual( self.observed_requests[-1], - glm.GenerateTextRequest( + protos.GenerateTextRequest( model="models/text-bison-001", - prompt=glm.TextPrompt(text="Hello"), + prompt=protos.TextPrompt(text="Hello"), stop_sequences=["stop"], ), ) @@ -282,9 +284,9 @@ def test_stop_string(self): ] ) def test_safety_settings(self, safety_settings): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ - glm.TextCompletion(output="No"), + protos.TextCompletion(output="No"), ] ) # This test really just checks that the safety_settings get converted to a proto. 
@@ -298,7 +300,7 @@ def test_safety_settings(self, safety_settings): ) def test_filters(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[{"output": "hello"}], filters=[ { @@ -313,7 +315,7 @@ def test_filters(self): self.assertEqual(response.filters[0]["reason"], palm_safety_types.BlockedReason.SAFETY) def test_safety_feedback(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[{"output": "hello"}], safety_feedback=[ { @@ -341,7 +343,7 @@ def test_safety_feedback(self): self.assertIsInstance( response.safety_feedback[0]["setting"]["category"], - glm.HarmCategory, + protos.HarmCategory, ) self.assertEqual( response.safety_feedback[0]["setting"]["category"], @@ -349,7 +351,7 @@ def test_safety_feedback(self): ) def test_candidate_safety_feedback(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ { "output": "hello", @@ -370,7 +372,7 @@ def test_candidate_safety_feedback(self): result = text_service.generate_text(prompt="Write a story from the ER.") self.assertIsInstance( result.candidates[0]["safety_ratings"][0]["category"], - glm.HarmCategory, + protos.HarmCategory, ) self.assertEqual( result.candidates[0]["safety_ratings"][0]["category"], @@ -387,7 +389,7 @@ def test_candidate_safety_feedback(self): ) def test_candidate_citations(self): - self.responses["generate_text"] = glm.GenerateTextResponse( + self.responses["generate_text"] = protos.GenerateTextResponse( candidates=[ { "output": "Hello Google!", @@ -434,21 +436,21 @@ def test_candidate_citations(self): ), ), dict( - testcase_name="glm_model", - model=glm.Model( + testcase_name="protos.model", + model=protos.Model( name="models/text-bison-001", ), ), dict( - testcase_name="glm_tuned_model", - model=glm.TunedModel( + testcase_name="protos.tuned_model", + model=protos.TunedModel( name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001", ), ), dict( - testcase_name="glm_tuned_model_nested", - model=glm.TunedModel( + testcase_name="protos.tuned_model_nested", + model=protos.TunedModel( name="tunedModels/bipedal-pangolin-002", tuned_model_source={ "tuned_model": "tunedModels/bipedal-pangolin-002", @@ -459,10 +461,10 @@ def test_candidate_citations(self): ] ) def test_count_message_tokens(self, model): - self.responses["get_tuned_model"] = glm.TunedModel( + self.responses["get_tuned_model"] = protos.TunedModel( name="tunedModels/bipedal-pangolin-001", base_model="models/text-bison-001" ) - self.responses["count_text_tokens"] = glm.CountTextTokensResponse(token_count=7) + self.responses["count_text_tokens"] = protos.CountTextTokensResponse(token_count=7) response = text_service.count_text_tokens(model, "Tell me a story about a magic backpack.") self.assertEqual({"token_count": 7}, response) @@ -472,7 +474,7 @@ def test_count_message_tokens(self, model): self.assertLen(self.observed_requests, 2) self.assertEqual( self.observed_requests[0], - glm.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"), + protos.GetTunedModelRequest(name="tunedModels/bipedal-pangolin-001"), ) def test_count_text_tokens_called_with_request_options(self):