Update to use google.ai.generativelanguage 0.2.0 #7

Merged · 15 commits · May 8, 2023
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -65,7 +65,7 @@ This "editable" mode lets you edit the source without needing to reinstall the p
 Use the builtin unittest package:

 ```
-python -m unittest
+python -m unittest discover --pattern '*test*.py'
 ```

 Or to debug, use:
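As a hedged aside on the new runner invocation: `discover` scopes test collection with `-s`/`--start-directory` and `-p`/`--pattern` (the `tests` directory below is illustrative, not taken from this PR):

```
# Discover and run every file matching the pattern, from the repo root.
python -m unittest discover --pattern '*test*.py'

# Scope discovery to one directory; -v prints one line per test.
python -m unittest discover -s tests --pattern '*test*.py' -v
```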
14 changes: 7 additions & 7 deletions google/generativeai/client.py
@@ -45,7 +45,7 @@ def configure(
     # but that seems rare. Users that need it can just switch to the low level API.
     transport: Union[str, None] = None,
     client_options: Union[client_options_lib.ClientOptions, dict, None] = None,
-    client_info: Optional[gapic_v1.client_info.ClientInfo] = None
+    client_info: Optional[gapic_v1.client_info.ClientInfo] = None,
 ):
     """Captures default client configuration.

@@ -86,13 +86,13 @@ def configure(

     user_agent = f"{USER_AGENT}/{version.__version__}"
     if client_info:
-        # Be respectful of any existing agent setting.
-        if client_info.user_agent:
-            client_info.user_agent += f" {user_agent}"
-        else:
-            client_info.user_agent = user_agent
+        # Be respectful of any existing agent setting.
+        if client_info.user_agent:
+            client_info.user_agent += f" {user_agent}"
+        else:
+            client_info.user_agent = user_agent
     else:
-        client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)
+        client_info = gapic_v1.client_info.ClientInfo(user_agent=user_agent)

     new_default_client_config = {
         "credentials": credentials,
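A minimal sketch of the user-agent merging this hunk preserves, assuming the public `configure()` entry point; the caller's agent string and the exact library agent value are illustrative:

```
from google.api_core import gapic_v1

import google.generativeai as genai

# Caller-supplied client info; configure() appends its own agent string
# rather than overwriting this value.
info = gapic_v1.client_info.ClientInfo(user_agent="my-app/1.0")
genai.configure(api_key="...", client_info=info)

# info.user_agent now reads like "my-app/1.0 genai-py/<version>"
# (library agent string assumed; the USER_AGENT constant is not shown here).
```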
29 changes: 25 additions & 4 deletions google/generativeai/discuss.py
@@ -25,6 +25,7 @@
 from google.generativeai.client import get_default_discuss_async_client
 from google.generativeai.types import discuss_types
 from google.generativeai.types import model_types
+from google.generativeai.types import safety_types


 def _make_message(content: discuss_types.MessageOptions) -> glm.Message:
@@ -389,8 +390,11 @@ def __init__(self, **kwargs):

     @property
     @set_doc(discuss_types.ChatResponse.last.__doc__)
-    def last(self) -> str:
-        return self.messages[-1]["content"]
+    def last(self) -> Optional[str]:
+        if self.messages[-1]:
+            return self.messages[-1]["content"]
+        else:
+            return None

     @last.setter
     def last(self, message: discuss_types.MessageOptions):
@@ -405,8 +409,16 @@ def reply(
             raise TypeError(
                 f"reply can't be called on an async client, use reply_async instead."
             )
+        if self.last is None:
+            raise ValueError(
+                "The last response from the model did not return any candidates.\n"
+                "Check the `.filters` attribute to see why the responses were filtered:\n"
+                f"{self.filters}"
+            )

         request = self.to_dict()
         request.pop("candidates")
+        request.pop("filters", None)
         request["messages"] = list(request["messages"])
         request["messages"].append(_make_message(message))
         request = _make_generate_message_request(**request)
@@ -422,6 +434,7 @@ async def reply_async(
             )
         request = self.to_dict()
         request.pop("candidates")
+        request.pop("filters")
         request["messages"] = list(request["messages"])
         request["messages"].append(_make_message(message))
         request = _make_generate_message_request(**request)
@@ -440,12 +453,20 @@ def _build_chat_response(
     request["messages"] = prompt["messages"]

     response = type(response).to_dict(response)
-    request["messages"].append(response["candidates"][0])
     response.pop("messages")

+    response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
+
+    if response["candidates"]:
+        last = response["candidates"][0]
+    else:
+        last = None
+    request["messages"].append(last)
     request.setdefault("temperature", None)
     request.setdefault("candidate_count", None)

     return ChatResponse(
-        _client=client, candidates=response["candidates"], **request
+        _client=client, **response, **request
     )  # pytype: disable=missing-parameter


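Taken together, these hunks change the contract of `ChatResponse.last`: it is now `None` whenever every candidate was blocked. A hedged usage sketch (prompt text is illustrative; `genai.chat` is the entry point this module backs):

```
import google.generativeai as genai

response = genai.chat(messages="Tell me a story about a brave anteater.")

if response.last is None:
    # All candidates were filtered; `filters` records the safety category,
    # probability, and threshold that triggered the block.
    print(response.filters)
else:
    # reply() now raises ValueError instead of failing obscurely when
    # there is no last message to respond to.
    response = response.reply("Please continue.")
```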
21 changes: 21 additions & 0 deletions google/generativeai/docstring_utils.py
@@ -0,0 +1,21 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def strip_oneof(docstring):
+    lines = docstring.splitlines()
+    lines = [line for line in lines if ".. _oneof:" not in line]
+    lines = [line for line in lines if "This field is a member of `oneof`_" not in line]
+    return "\n".join(lines)
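A quick sketch of what `strip_oneof` filters out; the docstring text below is illustrative, written in the style proto-plus generates for `oneof` fields:

```
from google.generativeai.docstring_utils import strip_oneof

doc = """A citation to a source.

.. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html

start_index:
    Start of the attributed segment.
    This field is a member of `oneof`_ ``_start_index``.
"""

# Drops the `.. _oneof:` link-target line and the "member of `oneof`_"
# lines, returning the rest of the docstring unchanged.
print(strip_oneof(doc))
```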
22 changes: 22 additions & 0 deletions google/generativeai/text.py
@@ -24,6 +24,7 @@
 from google.generativeai.client import get_default_text_client
 from google.generativeai.types import text_types
 from google.generativeai.types import model_types
+from google.generativeai.types import safety_types


 def _make_text_prompt(prompt: Union[str, dict[str, str]]) -> glm.TextPrompt:
@@ -44,6 +45,7 @@ def _make_generate_text_request(
     max_output_tokens: Optional[int] = None,
     top_p: Optional[int] = None,
     top_k: Optional[int] = None,
+    safety_settings: Optional[List[safety_types.SafetySettingDict]] = None,
     stop_sequences: Union[str, Iterable[str]] = None,
 ) -> glm.GenerateTextRequest:
     model = model_types.make_model_name(model)
@@ -61,6 +63,7 @@
         max_output_tokens=max_output_tokens,
         top_p=top_p,
         top_k=top_k,
+        safety_settings=safety_settings,
         stop_sequences=stop_sequences,
     )

@@ -74,6 +77,7 @@ def generate_text(
     max_output_tokens: Optional[int] = None,
     top_p: Optional[float] = None,
     top_k: Optional[float] = None,
+    safety_settings: Optional[Iterable[safety_types.SafetySettingDict]] = None,
     stop_sequences: Union[str, Iterable[str]] = None,
     client: Optional[glm.TextServiceClient] = None,
 ) -> text_types.Completion:
@@ -103,6 +107,15 @@ def generate_text(
           For example, if the sorted probabilities are
           `[0.5, 0.2, 0.1, 0.1, 0.05, 0.05]` a `top_p` of `0.8` will sample
           as `[0.625, 0.25, 0.125, 0, 0, 0]`.
+      safety_settings: A list of unique `types.SafetySetting` instances for blocking
+          unsafe content. These will be enforced on the `prompt` and `candidates`.
+          There should not be more than one setting for each `types.SafetyCategory`.
+          The API will block any prompts and responses that fail to meet the
+          thresholds set by these settings. Each entry overrides the default setting
+          for its `SafetyCategory`; if a category has no `types.SafetySetting` in
+          this list, the API uses the default safety setting for that category.
       stop_sequences: A set of up to 5 character sequences that will stop output generation.
           If specified, the API will stop at the first appearance of a stop
           sequence. The stop sequence will not be included as part of the response.
@@ -119,6 +132,7 @@ def generate_text(
         max_output_tokens=max_output_tokens,
         top_p=top_p,
         top_k=top_k,
+        safety_settings=safety_settings,
         stop_sequences=stop_sequences,
     )

@@ -145,6 +159,14 @@ def _generate_response(
     response = client.generate_text(request)
     response = type(response).to_dict(response)

+    response["filters"] = safety_types.convert_filters_to_enums(response["filters"])
+    response["safety_feedback"] = safety_types.convert_safety_feedback_to_enums(
+        response["safety_feedback"]
+    )
+    response["candidates"] = safety_types.convert_candidate_enums(
+        response["candidates"]
+    )
+
     return Completion(_client=client, **response)


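A hedged sketch of the new `safety_settings` parameter in use. The model name and enum members are drawn from the PaLM-era API this PR targets; treat them as illustrative rather than verified constants:

```
import google.generativeai as genai
from google.generativeai import types

completion = genai.generate_text(
    model="models/text-bison-001",
    prompt="Write a limerick about safety filters.",
    safety_settings=[
        {
            # Block this category even at low harm probability.
            "category": types.HarmCategory.HARM_CATEGORY_TOXICITY,
            "threshold": types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
        }
    ],
)

# _generate_response() now converts `filters`, `safety_feedback`, and the
# per-candidate ratings from raw ints into the readable enum types.
print(completion.result)
```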
5 changes: 5 additions & 0 deletions google/generativeai/types/__init__.py
@@ -17,6 +17,11 @@
 from google.generativeai.types.discuss_types import *
 from google.generativeai.types.model_types import *
 from google.generativeai.types.text_types import *
+from google.generativeai.types.citation_types import *
+from google.generativeai.types.safety_types import *

 del discuss_types
 del model_types
 del text_types
+del citation_types
+del safety_types
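The star-import-then-`del` pattern re-exports each submodule's public names at `google.generativeai.types` while keeping the submodule names themselves out of the package namespace. A small sketch of the effect, with field names assumed from the `safety_types` module this PR adds:

```
from google.generativeai import types

# Re-exported names are available at the package level:
settings: types.SafetySettingDict = {
    "category": types.HarmCategory.HARM_CATEGORY_TOXICITY,
    "threshold": types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
}

# ...but the submodule attributes were removed by the `del` statements:
# `types.safety_types` raises AttributeError unless imported explicitly.
```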
39 changes: 39 additions & 0 deletions google/generativeai/types/citation_types.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# Copyright 2023 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, List
+
+from google.ai import generativelanguage as glm
+from google.generativeai import docstring_utils
+from typing import TypedDict
+
+__all__ = [
+    "CitationMetadataDict",
+    "CitationSourceDict",
+]
+
+
+class CitationSourceDict(TypedDict):
+    start_index: Optional[int]
+    end_index: Optional[int]
+    uri: Optional[str]
+    license: Optional[str]
+
+    __doc__ = docstring_utils.strip_oneof(glm.CitationSource.__doc__)
+
+
+class CitationMetadataDict(TypedDict):
+    citation_sources: Optional[List[CitationSourceDict]]
+
+    __doc__ = docstring_utils.strip_oneof(glm.CitationMetadata.__doc__)
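A sketch of the dict shapes these TypedDicts describe; all values below are illustrative:

```
from google.generativeai.types import CitationMetadataDict, CitationSourceDict

source: CitationSourceDict = {
    "start_index": 0,
    "end_index": 42,
    "uri": "https://example.com/source",
    "license": "mit",
}

metadata: CitationMetadataDict = {"citation_sources": [source]}
```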
20 changes: 16 additions & 4 deletions google/generativeai/types/discuss_types.py
@@ -19,6 +19,8 @@
 from typing import Any, Dict, TypedDict, Union, Iterable, Optional, Tuple, List

 import google.ai.generativelanguage as glm
+from google.generativeai.types import safety_types
+from google.generativeai.types import citation_types

 __all__ = [
     "MessageDict",
@@ -35,11 +37,12 @@
 ]


-class MessageDict(TypedDict, total=False):
+class MessageDict(TypedDict):
     """A dict representation of a `glm.Message`."""

     author: str
     content: str
+    citation_metadata: Optional[citation_types.CitationMetadataDict]


 MessageOptions = Union[str, MessageDict, glm.Message]
@@ -129,7 +132,14 @@ class ChatResponse(abc.ABC):
     Note: The `temperature` field affects the variability of the responses. Low
     temperatures will return few candidates. Setting `temperature=0` is deterministic,
     so it will only ever return one candidate.
+
+    filters: This indicates which `types.SafetyCategory`(s) blocked a
+        candidate from this response, the lowest `types.HarmProbability`
+        that triggered a block, and the `types.HarmThreshold` setting for that category.
+        This indicates the smallest change to the `types.SafetySettings` that would be
+        necessary to unblock at least one response.
+
+        The blocking is configured by the `types.SafetySettings` in the request (or the
+        default `types.SafetySettings` of the API).
     messages: Contains all the `messages` that were passed when the model was called,
         plus the top `candidate` message.
     model: The model name.
@@ -140,21 +150,23 @@ class ChatResponse(abc.ABC):
     candidate_count: The **maximum** number of generated response messages to return.
     top_k: The maximum number of tokens to consider when sampling.
     top_p: The maximum cumulative probability of tokens to consider when sampling.
     """

     model: str
     context: str
     examples: List[ExampleDict]
-    messages: List[MessageDict]
+    messages: List[Optional[MessageDict]]
     temperature: Optional[float]
     candidate_count: Optional[int]
     candidates: List[MessageDict]
     top_p: Optional[float] = None
     top_k: Optional[float] = None
+    filters: List[safety_types.ContentFilterDict]

     @property
     @abc.abstractmethod
-    def last(self) -> str:
+    def last(self) -> Optional[str]:
         """A settable property that provides simple access to the last response string

         A shortcut for `response.messages[-1]['content']`.
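Since `MessageDict` is no longer `total=False`, a literal now needs all three keys. A sketch of the updated shape, with the nested citation dicts mirroring `citation_types` above (all values illustrative):

```
from google.generativeai import types

message: types.MessageDict = {
    "author": "0",
    "content": "PaLM can cite its sources.",
    "citation_metadata": {
        "citation_sources": [
            {
                "start_index": 0,
                "end_index": 26,
                "uri": "https://example.com",
                "license": "",
            }
        ]
    },
}
```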