Skip to content

Commit

Permalink
feat: GenAI - Switched the GA version of the generative_models classes to use the v1 service APIs instead of v1beta1
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 675454152
  • Loading branch information
Ark-kun authored and copybara-github committed Sep 17, 2024
1 parent f78b953 commit 66d84af
Show file tree
Hide file tree
Showing 3 changed files with 329 additions and 132 deletions.
190 changes: 90 additions & 100 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@

import vertexai
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform_v1 import types as types_v1
from google.cloud.aiplatform_v1.services import (
prediction_service as prediction_service_v1,
)
from google.cloud.aiplatform_v1beta1 import types as types_v1beta1
from vertexai import generative_models
from vertexai.preview import (
generative_models as preview_generative_models,
Expand Down Expand Up @@ -326,6 +331,72 @@ def mock_stream_generate_content(
yield blocked_chunk


def mock_generate_content_v1(
    self,
    request: types_v1.GenerateContentRequest,
    *,
    model: Optional[str] = None,
    contents: Optional[MutableSequence[types_v1.Content]] = None,
) -> types_v1.GenerateContentResponse:
    """Mocks the v1 GenerateContent RPC by delegating to the v1beta1 mock.

    The v1 and v1beta1 protos are wire-compatible, so a serialize/deserialize
    round trip converts the request and response between the two API versions.
    """
    # v1 request -> v1beta1 request via proto wire-format round trip.
    beta_request = types_v1beta1.GenerateContentRequest.deserialize(
        type(request).serialize(request)
    )
    beta_response = mock_generate_content(self=self, request=beta_request)
    # v1beta1 response -> v1 response the same way.
    return types_v1.GenerateContentResponse.deserialize(
        type(beta_response).serialize(beta_response)
    )


def mock_stream_generate_content_v1(
    self,
    request: types_v1.GenerateContentRequest,
    *,
    model: Optional[str] = None,
    contents: Optional[MutableSequence[types_v1.Content]] = None,
) -> Iterable[types_v1.GenerateContentResponse]:
    """Mocks the v1 StreamGenerateContent RPC via the v1beta1 streaming mock.

    Converts the request to v1beta1, streams chunks from the v1beta1 mock,
    and converts each chunk back to v1 through a proto wire-format round trip.
    """
    beta_request = types_v1beta1.GenerateContentRequest.deserialize(
        type(request).serialize(request)
    )
    beta_stream = mock_stream_generate_content(self=self, request=beta_request)
    for beta_chunk in beta_stream:
        yield types_v1.GenerateContentResponse.deserialize(
            type(beta_chunk).serialize(beta_chunk)
        )


def patch_genai_services(func: callable):
    """Patches GenAI services (v1 and v1beta1, streaming and non-streaming).

    Decorator that stacks four ``mock.patch.object`` patches onto ``func`` so
    the wrapped test never reaches the real PredictionService backends:

      * v1beta1 ``generate_content`` and ``stream_generate_content``
      * v1 ``generate_content`` and ``stream_generate_content``

    Args:
        func: The test function to decorate.

    Returns:
        ``func`` wrapped by all four patches.
    """
    # Data-driven patching replaces four hand-unrolled, near-identical
    # mock.patch.object applications; order of stacking is not significant
    # because each patch targets a distinct (client, attribute) pair.
    _patches = (
        (
            prediction_service.PredictionServiceClient,
            "generate_content",
            mock_generate_content,
        ),
        (
            prediction_service_v1.PredictionServiceClient,
            "generate_content",
            mock_generate_content_v1,
        ),
        (
            prediction_service.PredictionServiceClient,
            "stream_generate_content",
            mock_stream_generate_content,
        ),
        (
            prediction_service_v1.PredictionServiceClient,
            "stream_generate_content",
            mock_stream_generate_content_v1,
        ),
    )
    for target, attribute, replacement in _patches:
        func = mock.patch.object(
            target=target,
            attribute=attribute,
            new=replacement,
        )(func)
    return func


@pytest.fixture
def mock_get_cached_content_fixture():
"""Mocks GenAiCacheServiceClient.get_cached_content()."""
Expand Down Expand Up @@ -376,11 +447,6 @@ def setup_method(self):
def teardown_method(self):
initializer.global_pool.shutdown(wait=True)

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
Expand Down Expand Up @@ -489,11 +555,7 @@ def test_generative_model_from_cached_content_with_resource_name(
== "cached-content-id-in-from-cached-content-test"
)

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
Expand Down Expand Up @@ -601,11 +663,7 @@ def test_generate_content_with_cached_content(

assert response.text == "response to " + cached_content.resource_name

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="stream_generate_content",
new=mock_stream_generate_content,
)
@patch_genai_services
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
Expand All @@ -616,11 +674,7 @@ def test_generate_content_streaming(self, generative_models: generative_models):
for chunk in stream:
assert chunk.text

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
Expand Down Expand Up @@ -668,11 +722,7 @@ def test_generate_content_response_accessor_errors(
assert e.match("no text")
assert e.match("function_call")

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
Expand All @@ -685,11 +735,7 @@ def test_chat_send_message(self, generative_models: generative_models):
response2 = chat.send_message("Is sky blue on other planets?")
assert response2.text

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="stream_generate_content",
new=mock_stream_generate_content,
)
@patch_genai_services
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
Expand All @@ -704,11 +750,7 @@ def test_chat_send_message_streaming(self, generative_models: generative_models)
for chunk in stream2:
assert chunk.candidates

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
Expand All @@ -727,11 +769,7 @@ def test_chat_send_message_response_validation_errors(
# Checking that history did not get updated
assert len(chat.history) == 2

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
Expand All @@ -754,11 +792,7 @@ def test_chat_send_message_response_blocked_errors(
# Checking that history did not get updated
assert len(chat.history) == 2

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
Expand All @@ -775,11 +809,7 @@ def test_chat_send_message_response_candidate_blocked_error(
# Checking that history did not get updated
assert not chat.history

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
Expand Down Expand Up @@ -808,11 +838,7 @@ def test_finish_reason_max_tokens_in_generate_content_and_send_message(
# Verify that history did not get updated
assert not chat.history

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
Expand Down Expand Up @@ -861,11 +887,7 @@ def test_chat_function_calling(self, generative_models: generative_models):
assert "nice" in response2.text
assert not response2.candidates[0].function_calls

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
Expand Down Expand Up @@ -922,11 +944,7 @@ def test_chat_forced_function_calling(self, generative_models: generative_models
assert "nice" in response2.text
assert not response2.candidates[0].function_calls

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
@pytest.mark.parametrize(
"generative_models",
[generative_models, preview_generative_models],
Expand Down Expand Up @@ -982,11 +1000,7 @@ def test_conversion_methods(self, generative_models: generative_models):
# Checking that the enums are serialized as strings, not integers.
assert response.to_dict()["candidates"][0]["finish_reason"] == "STOP"

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
def test_generate_content_grounding_google_search_retriever_preview(self):
model = preview_generative_models.GenerativeModel("gemini-pro")
google_search_retriever_tool = (
Expand All @@ -999,11 +1013,7 @@ def test_generate_content_grounding_google_search_retriever_preview(self):
)
assert response.text

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
def test_generate_content_grounding_google_search_retriever(self):
model = generative_models.GenerativeModel("gemini-pro")
google_search_retriever_tool = (
Expand All @@ -1016,11 +1026,7 @@ def test_generate_content_grounding_google_search_retriever(self):
)
assert response.text

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
def test_generate_content_grounding_vertex_ai_search_retriever(self):
model = preview_generative_models.GenerativeModel("gemini-pro")
vertex_ai_search_retriever_tool = preview_generative_models.Tool.from_retrieval(
Expand All @@ -1035,11 +1041,7 @@ def test_generate_content_grounding_vertex_ai_search_retriever(self):
)
assert response.text

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
def test_generate_content_grounding_vertex_ai_search_retriever_with_project_and_location(
self,
):
Expand All @@ -1058,11 +1060,7 @@ def test_generate_content_grounding_vertex_ai_search_retriever_with_project_and_
)
assert response.text

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
def test_generate_content_vertex_rag_retriever(self):
model = preview_generative_models.GenerativeModel("gemini-pro")
rag_resources = [
Expand All @@ -1085,11 +1083,7 @@ def test_generate_content_vertex_rag_retriever(self):
)
assert response.text

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
def test_chat_automatic_function_calling_with_function_returning_dict(self):
generative_models = preview_generative_models
get_current_weather_func = generative_models.FunctionDeclaration.from_func(
Expand Down Expand Up @@ -1124,11 +1118,7 @@ def test_chat_automatic_function_calling_with_function_returning_dict(self):
chat2.send_message("What is the weather like in Boston?")
assert err.match("Exceeded the maximum")

@mock.patch.object(
target=prediction_service.PredictionServiceClient,
attribute="generate_content",
new=mock_generate_content,
)
@patch_genai_services
def test_chat_automatic_function_calling_with_function_returning_value(self):
# Define a new function that returns a value instead of a dict.
def get_current_weather(location: str):
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/vertexai/test_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
# TODO(b/360932655): Use mock_generate_content from test_generative_models
from vertexai.preview import rag
from vertexai.generative_models._generative_models import (
prediction_service,
gapic_prediction_service_types,
gapic_content_types,
gapic_tool_types,
prediction_service_v1 as prediction_service,
types_v1 as gapic_prediction_service_types,
types_v1 as gapic_content_types,
types_v1 as gapic_tool_types,
)


Expand Down
Loading

0 comments on commit 66d84af

Please sign in to comment.